In [None]:
import glob
import json
import os
import re
from collections import Counter

import numpy as np
import pandas as pd
import plotly.express as exp
from tqdm.autonotebook import tqdm

def jaccard(str1, str2):
    """Return the Jaccard similarity of the word sets of two strings.

    Both strings are lower-cased and split on whitespace; the score is
    ``|A & B| / |A | B|`` over the resulting sets of tokens.

    Args:
        str1: First string.
        str2: Second string.

    Returns:
        A float in ``[0.0, 1.0]``. When neither string contains a token the
        union is empty; ``0.0`` is returned in that case instead of raising
        ``ZeroDivisionError``.
    """
    a = set(str1.lower().split())
    b = set(str2.lower().split())
    union_size = len(a | b)
    if union_size == 0:
        # Two blank strings share nothing measurable; avoid dividing by zero.
        return 0.0
    return len(a & b) / union_size

In [None]:
import transformers
import torch.nn as nn
import torch

In [None]:
def clean_text(txt):
    """Normalise ``txt`` to lower-case alphanumeric words separated by single spaces.

    The value is converted with ``str()``, lower-cased, every run of characters
    outside ``A-Za-z0-9`` is replaced by one space, and surrounding whitespace
    is stripped.
    """
    lowered = str(txt).lower()
    spaced = re.sub('[^A-Za-z0-9]+', ' ', lowered)
    return spaced.strip()


def totally_clean_text(txt):
    """Apply ``clean_text`` to ``txt`` and then collapse any run of spaces to one space."""
    cleaned = clean_text(txt)
    return re.sub(' +', ' ', cleaned)



In [None]:
# Load the competition's train.csv; later cells read its `Id` and `cleaned_label` columns.
train = pd.read_csv("../input/coleridgeinitiative-show-us-the-data/train.csv")

In [None]:
# Directory of the competition's train split (not read anywhere else in this notebook).
base_path = "../input/coleridgeinitiative-show-us-the-data/train"
# Values of the `Id` column of `train`, as a NumPy array.
train_files = train.Id.to_numpy()

In [None]:
def _split_candidate_phrases(tokens):
    """Join ``tokens`` with spaces, split on ``. , : ?`` and keep pieces of 3+ words.

    Each surviving piece is normalised with ``clean_text`` before it is returned.
    """
    phrases = []
    for piece in re.split("[.,:?]", " ".join(tokens)):
        if len(piece.split()) > 2:
            phrases.append(clean_text(piece))
    return phrases


def get_dataset_phrase(entire):
    """Extract capitalised multi-word phrases from ``entire``.

    The text is split on whitespace and scanned for runs of words that start
    with an upper-case letter. A lower-case word is kept inside a run only when
    the word after it is capitalised (e.g. the "of" in "Survey of Income").
    The words "the" and "using" (any casing) are skipped without ending a run.
    A run ends at the first lower-case word not followed by a capitalised one,
    or at the end of the text.

    Every finished run is split on ``. , : ?`` and each piece with more than two
    words is returned after ``clean_text`` normalisation.

    Args:
        entire: Raw document text.

    Returns:
        List of cleaned candidate phrases, in order of appearance (duplicates kept).
    """
    words = entire.split()
    invalid_w = {'the', 'using'}
    prop = []
    run = []
    for i, word in enumerate(words):
        if word.lower() in invalid_w:
            continue
        if word[0].isupper():
            run.append(word)
            continue
        if not run:
            continue
        # Guarded look-ahead so the final words of the text are scanned as well.
        next_is_upper = i + 1 < len(words) and words[i + 1][0].isupper()
        if next_is_upper:
            # Connector inside a name; joined with spaces like every other token
            # so it is not glued to the following word.
            run.append(word)
        else:
            prop.extend(_split_candidate_phrases(run))
            run = []
    # A run that reaches the end of the text is still a candidate.
    prop.extend(_split_candidate_phrases(run))
    return prop

In [None]:
# Distinct values of the `cleaned_label` column; this array is what find_dataset searches for.
train_datasets = train.cleaned_label.unique()

In [None]:
# Additional dataset list from an external CSV; only its `title` column is used below.
extra_dsets = pd.read_csv('../input/bigger-govt-dataset-list/data_set_800.csv')

In [None]:
# Keep only the extra titles made of more than one word, normalised with clean_text.
filtered = [
    clean_text(title)
    for title in extra_dsets.title
    if len(title.split()) > 1
]

In [None]:
# Append the filtered extra titles to the label array.
# NOTE(review): re-running this cell appends the titles again; the np.unique cell that follows removes the duplicates.
train_datasets = np.hstack([train_datasets,np.array(filtered)])

In [None]:
# De-duplicate (and sort) the combined label array.
train_datasets = np.unique(train_datasets)

In [None]:
# Show how many distinct labels find_dataset will search for.
len(train_datasets)

In [None]:
def find_dataset(text):
    """Return every entry of the global ``train_datasets`` that occurs as a substring of ``text``.

    Matching is plain substring containment, so the caller is expected to pass
    text normalised the same way as the labels.
    """
    return [label for label in train_datasets if label in text]

In [None]:
def find_dataset_rule(entire):
    """Return up to two rule-based dataset candidates found in ``entire``.

    Candidate phrases come from ``get_dataset_phrase``. A phrase is kept when at
    least one of its words is a dataset keyword ("dataset", "survey", "data",
    ...). Kept phrases are ranked by how often they occur and the two most
    frequent ones are returned.

    Args:
        entire: Raw document text.

    Returns:
        List of at most two phrases, most frequent first. Ties are broken
        alphabetically so the result is deterministic. Returns ``[]`` when no
        phrase qualifies.
    """
    useful = {'dataset', 'survey', 'data', 'database', 'study', 'atlas',
              'collection', 'sequence', 'sequences'}
    counts = Counter(
        phrase
        for phrase in get_dataset_phrase(entire)
        if phrase and not useful.isdisjoint(phrase.split())
    )
    # Counting with the standard library keeps the empty case trivial and avoids
    # depending on how pandas ranks (or rejects) a frame with no columns.
    ranked = sorted(counts.items(), key=lambda item: (-item[1], item[0]))
    return [phrase for phrase, _ in ranked[:2]]

In [None]:
# Directory containing one JSON file per test document.
test_path = "../input/coleridgeinitiative-show-us-the-data/test"

In [None]:
# File names (with extension) found in the test directory.
test_files = os.listdir(test_path)

In [None]:
# One empty placeholder prediction per test file; the real strings are filled in later.
prediction = [""] * len(test_files)

In [None]:
# test_files = [f for f in train_files]
# submission = pd.DataFrame([test_files]).T
# submission["Prediction"] = ""
# submission.columns  = ['Id', 'PredictionString']

In [None]:
# Turn file names into document ids by dropping the extension. splitext is used
# instead of slicing off five characters so that re-running this cell does not
# eat into ids that have already been stripped.
test_files = [os.path.splitext(f)[0] for f in test_files]
# Submission skeleton: one row per test document with an empty prediction.
submission = pd.DataFrame({'Id': test_files, 'PredictionString': prediction})

In [None]:
class DatasetFinder(nn.Module):
    """Token-level scorer: frozen pretrained encoder -> bidirectional LSTM -> linear -> sigmoid.

    ``forward`` returns one value in (0, 1) per input token position.
    """

    def __init__(self, params):
        """Build the network.

        Args:
            params: Mapping that must contain ``'hid_size'``, the LSTM hidden
                size per direction.
        """
        super().__init__()
        hidden = params['hid_size']
        # Pretrained encoder loaded from a local directory; its weights are frozen.
        self.model = transformers.AutoModel.from_pretrained("../input/scibert-huggingface/coleridge-scibert-models/output")
        for weight in self.model.parameters():
            weight.requires_grad = False
        # The LSTM consumes 768-dimensional encoder states.
        self.lstm = nn.LSTM(input_size=768, hidden_size=hidden, bidirectional=True, batch_first=True)
        # Both LSTM directions are concatenated, hence hidden * 2 input features.
        self.fc = nn.Linear(hidden * 2, 1)
        # Plain integer attribute; not read anywhere in this class.
        self.e = 0

    def forward(self, inp):
        """Score every token of a batch.

        Args:
            inp: Keyword arguments for the encoder (unpacked with ``**``).

        Returns:
            Tensor of sigmoid outputs with the trailing singleton dimension removed.
        """
        encoder_states = self.model(**inp).last_hidden_state
        lstm_out, _ = self.lstm(encoder_states)
        logits = self.fc(lstm_out).squeeze(2)
        return torch.sigmoid(logits)

In [None]:
# Tokenizer matching the pretrained encoder used inside DatasetFinder.
tokenizer = transformers.AutoTokenizer.from_pretrained('../input/scibert-huggingface/coleridge-scibert-models/output')
# Run on the first GPU when one is available, otherwise on CPU.
DEVICE = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Number of checkpoints in the ensemble (model_best_0 .. model_best_{M_COUNT-1}).
M_COUNT=5

# Hyper-parameters handed to DatasetFinder; only 'hid_size' is read by the class.
params = {
    'lr':0.001,
    'loss_func':torch.nn.BCELoss(),
    'hid_size':512
}

# Build the inference ensemble: one frozen, eval-mode DatasetFinder per checkpoint.
models = []
for fold in range(M_COUNT):
    finder = DatasetFinder(params)
    # map_location='cpu' lets checkpoints saved on a GPU load on a CPU-only
    # machine; the model is moved to DEVICE afterwards.
    state = torch.load(
        f'../input/processed-train-data-sentence-segmentaion/model_best_{fold}.state',
        map_location='cpu',
    )
    finder.load_state_dict(state)
    # The loop variable must not be called `params`: that would overwrite the
    # hyper-parameter dict above and break a re-run of this cell.
    for weight in finder.parameters():
        weight.requires_grad = False
    finder.eval()
    models.append(finder.to(DEVICE))

In [None]:
def _decode_kept_spans(token_row):
    """Decode each run of non-zero token ids in ``token_row`` into a cleaned phrase.

    Id 0 acts as the separator: it is what a token id becomes after being
    multiplied by a False mask (and is presumably also the tokenizer's padding
    id - TODO confirm). Only decoded runs of more than two whitespace-separated
    words are kept, normalised with ``clean_text``.
    """
    phrases = []
    span = []
    # The trailing 0 is a sentinel so a run touching the end of the row is flushed too.
    for token_id in list(token_row) + [0]:
        if token_id != 0:
            span.append(token_id)
            continue
        if span:
            word = tokenizer.decode(span)
            if len(word.split()) > 2:
                phrases.append(clean_text(word))
            span = []
    return phrases


def find_dataset_ann(text, batch_size=256, threshold=0.35, max_length=30):
    """Extract dataset mentions from ``text`` with the model ensemble.

    The text is split on ``'.'``; each piece is tokenised to exactly
    ``max_length`` tokens (padded or truncated). Every model in the global
    ``models`` list scores each token, the scores are averaged, and token ids
    whose mean score exceeds ``threshold`` are kept. Consecutive kept tokens are
    decoded into phrases; phrases of more than two words are returned.

    Sentences are processed ``batch_size`` at a time so a long document is not
    pushed through the ensemble as one huge batch. Each sentence is padded to
    the same fixed length either way, so batching only changes how much memory
    is used at once.

    Args:
        text: Raw document text.
        batch_size: Number of sentences per forward pass.
        threshold: Minimum mean score for a token to be kept.
        max_length: Token length each sentence is padded/truncated to.

    Returns:
        List of cleaned phrases (possibly empty, duplicates kept).
    """
    sentences = text.split('.')
    answers = []
    # Inference only: no autograd bookkeeping is needed.
    with torch.no_grad():
        for start in range(0, len(sentences), batch_size):
            encoded = tokenizer(
                sentences[start:start + batch_size],
                max_length=max_length,
                padding="max_length",
                truncation=True,
                return_tensors="pt",
            )
            inp = {
                "input_ids": encoded["input_ids"].to(DEVICE),
                "attention_mask": encoded["attention_mask"].to(DEVICE),
            }
            # Average the ensemble in tensor space on the CPU.
            mean_scores = sum(model(inp).detach().cpu() for model in models) / len(models)
            # Multiplying by the boolean mask zeroes every token below the threshold.
            kept_ids = (encoded["input_ids"].cpu() * (mean_scores > threshold)).tolist()
            for row in kept_ids:
                answers.extend(_decode_kept_spans(row))
    return answers

In [None]:
# Run the three extractors (capitalised-phrase rules, known-label lookup, model
# ensemble) on every test document and record which labels each document produced.
dsets = []      # one entry per (document, label) pair: the label
ids = []        # one entry per (document, label) pair: the document's row index
dset_comb = []  # one set of labels per document, in submission order
for i in tqdm(range(len(submission))):
    doc_id = submission.iloc[i, 0]
    # Context manager so the file handle is closed after each document.
    with open(os.path.join(test_path, doc_id + ".json")) as handle:
        sections = json.load(handle)
    # Sections are joined with a space so the last word of one section is not
    # glued to the title of the next one.
    entire = " ".join(x['section_title'] + " " + x['text'] for x in sections)
    pred1 = find_dataset_rule(entire)
    pred2 = find_dataset(clean_text(entire))
    pred3 = []
    try:
        pred3 = find_dataset_ann(entire)
    except Exception as err:
        # Best effort: a model failure (e.g. out of memory) must not stop the
        # run, but it is reported with the document id instead of being hidden.
        torch.cuda.empty_cache()
        print(f"error in {doc_id}: {err!r}")
    pred = set(pred1) | set(pred2) | set(pred3)
    dset_comb.append(pred)
    for p in pred:
        # Store the individual label (not the whole set) so the later
        # `datasets.dsets == label` comparison can match.
        dsets.append(p)
        ids.append(i)
datasets = pd.DataFrame({'ids': ids, 'dsets': dsets})

In [None]:
# Keep a label for a document only if at least one other document produced it too.
# Each document contributes a set, so a label is counted at most once per document
# and a count above 1 means "also seen elsewhere".
doc_counts = Counter(label for labels in dset_comb for label in labels)
for i, labels in enumerate(dset_comb):
    # Sorted so the prediction string does not depend on set iteration order.
    shared = sorted(label for label in labels if doc_counts[label] > 1)
    submission.loc[i, "PredictionString"] = "|".join(shared)

In [None]:
# Write the submission file to the working directory, without the index column.
submission.to_csv("submission.csv",index=False)

In [None]:
# Display the final submission frame as the cell output.
submission