In [2]:
import pandas as pd
from pathlib import Path
import numpy as np
import torch
from transformers import ReformerModelWithLMHead, ReformerConfig, ReformerTokenizer
from transformers import AdamW
from transformers import get_linear_schedule_with_warmup
from torch.utils.data import DataLoader
torch.cuda.current_device()
torch.cuda._initialized = True

In [3]:
ocr_path = Path("/home/allekim/stonybook-data/hathi/ocr_model_results/double_books/")
result_paths = list(ocr_path.glob('*'))
df = pd.read_csv(result_paths[1], converters={'ctx1': eval, 'ctx2': eval, 'diff1': eval, 'diff2': eval})

In [4]:
def generate_examples(row):
    loss1, loss2 = row['loss1'], row['loss2']
    diff1, diff2 = row['diff1'], row['diff2']
    ctx1, ctx2 = row['ctx1'], row['ctx2']
    if any(['*' in x for x in ctx1]) or any(['*' in x for x in ctx2]):
        return np.nan
    ocr1, ocr2 = ctx1[diff1[0]:diff1[1]], ctx2[diff2[0]:diff2[1]]
    ex1 = ' '.join(ctx1[:diff1[0]]) + '*' + ' '.join(ocr1) + '*' + ' '.join(ctx1[diff1[1]:])
    ex2 = ' '.join(ctx2[:diff2[0]]) + '*' + ' '.join(ocr2) + '*' + ' '.join(ctx2[diff2[1]:])
    if loss1 < loss2:
        correct = '#' + ' '.join(ocr1) + '#'    
    else:
        correct = '#' + ' '.join(ocr2) + '#'
    return (ex1 + correct, ex2 + correct)


In [7]:
df['examples'] = df.apply(generate_examples, axis=1)

In [18]:
df.to_csv('test.csv')

In [8]:
result = [e for l in df['examples'].dropna() for e in l]

In [9]:
class OCRDataset(torch.utils.data.Dataset):
    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        item = {key: val[idx].clone().detach() for key, val in self.encodings.items()}
        item['labels'] = self.encodings['input_ids'][idx].clone().detach()
        return item

    def __len__(self):
        return len(self.encodings)


In [10]:
training_data = OCRDataset(encode(result[:450]))
test_data = OCRDataset(encode(result[450:]))
train_dataloader = DataLoader(training_data, batch_size=1, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=1, shuffle=True)

In [11]:
model = ReformerModelWithLMHead.from_pretrained("google/reformer-enwik8")
model.train()

ReformerModelWithLMHead(
  (reformer): ReformerModel(
    (embeddings): ReformerEmbeddings(
      (word_embeddings): Embedding(258, 1024)
      (position_embeddings): AxialPositionEmbeddings(
        (weights): ParameterList(
            (0): Parameter containing: [torch.FloatTensor of size 128x1x256]
            (1): Parameter containing: [torch.FloatTensor of size 1x512x768]
        )
      )
    )
    (encoder): ReformerEncoder(
      (layers): ModuleList(
        (0): ReformerLayer(
          (attention): ReformerAttention(
            (layer_norm): LayerNorm((1024,), eps=1e-12, elementwise_affine=True)
            (self_attention): LocalSelfAttention(
              (query): Linear(in_features=1024, out_features=1024, bias=False)
              (key): Linear(in_features=1024, out_features=1024, bias=False)
              (value): Linear(in_features=1024, out_features=1024, bias=False)
            )
            (output): ReformerSelfOutput(
              (dense): Linear(in_features=10

In [12]:
optimizer = AdamW(model.parameters(), lr=1e-5)
num_warmup_steps = 500
num_train_steps = 500
scheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_train_steps)

In [13]:
# device = torch.device("cuda:3")
# model = model.to(device)
# x = next(iter(train_dataloader))
# input_ids = x["input_ids"].to(device)
# attention_mask = x["attention_mask"].to(device)
# labels = x["labels"].to(device)

In [14]:
# output = model(input_ids, attention_mask=attention_mask, labels=labels)

In [15]:
# output.loss.backward()

In [16]:
# for ex in train_dataloader:
#     input_ids = ex["input_ids"].to(device)
#     attention_mask = ex["attention_mask"].to(device)
#     labels = ex["labels"].to(device)
#     outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
#     loss = outputs.loss
#     loss.backward()
#     optimizer.step()
#     scheduler.step()
#     print(loss)
#     break