In [60]:
import copy

import numpy as np
import pandas as pd
import glob
import os
import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoModel, AutoTokenizer
from dtu_proj.data.dataset import UserDataset
from dtu_proj.models.model import BERTClassifier

# Load IPython's autoreload extension and select mode 2, which re-imports
# modules before each cell executes so edits to imported project code are
# picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

# Use the GPU when torch can see one, otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
cpu


In [61]:
# Hugging Face checkpoint for the small BERT variant (author's note: L=2, H=128).
checkpoint = "prajjwal1/bert-tiny"

# The tokenizer and the encoder are loaded from the same checkpoint name.
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModel.from_pretrained(checkpoint)

In [62]:
# Root data folder (relative to the notebook) and its two sub-folders.
data_dir = '../data'
raw_dir = f'{data_dir}/raw'
processed_dir = f'{data_dir}/processed'

In [63]:
# Load the processed book catalogue and display it.
# NOTE(review): the displayed "Unnamed: 0" column suggests the CSV was written
# together with its index; consider index_col=0 if nothing downstream needs it.
books_csv = os.path.join(processed_dir, 'books.csv')
books = pd.read_csv(books_csv)
books

Unnamed: 0,name,book_id,authors,avg_rating,publish_year,desc
0,!%@:: A Directory of Electronic Mail Addressin...,3682820,Rick Adams,0.00,1993,"These days, it's a rare person who hasn't hear..."
1,!Arrebátalo!: La fe que se mantiene firme ante...,3001790,Judy Jacobs,4.00,2006,<i>¡Aprende a caminar en el poder de la fe vio...
2,"!Buen viaje!, Course 1, Student Edition",3353252,McGraw-Hill Education,2.80,1999,This three-level program brings you every reso...
3,!Búscalo! (Look It Up!): A Quick Reference Gui...,1404770,William M. Clarkson,3.82,1998,A novel approach--very useful for quick refere...
4,!Click Song,1621343,John A. Williams,3.80,1987,"At lunch in a restaurant on the East Side, Cat..."
...,...,...,...,...,...,...
1064895,�� 305-310; Uklag: (recht Der Allgemeinen Gesc...,1373998,Michael Coester,0.00,2006,<br />Das Recht der AGB-Inhaltskontrolle hat d...
1064896,"�� 328-359: (vertrag Zugunsten Dritter, R�cktr...",4725551,Julius von Staudinger,0.00,2004,Der Band erl�utert die Vorschriften �ber den V...
1064897,�� 491-507: (verbraucherdarlehen),4725546,Julius von Staudinger,0.00,2004,Der Band enth�lt eine geschlossene Darstellung...
1064898,�� 50-127a,3544662,Rainer Hausmann,0.00,1994,The Wieczorek/Sch�tze commentary covers German...


In [64]:
# Load the processed user ratings (user_id, book name, rating) and display them.
# NOTE(review): as with books.csv, the "Unnamed: 0" column looks like a saved index.
ratings_csv = os.path.join(processed_dir, 'ratings.csv')
ratings = pd.read_csv(ratings_csv)
ratings

Unnamed: 0,user_id,name,rating
0,1,The Restaurant at the End of the Universe (Hit...,5
1,1,Siddhartha,5
2,1,"The Hunger Games (The Hunger Games, #1)",5
3,1,The Authoritative Calvin and Hobbes: A Calvin ...,5
4,1,The Return of the Indian (The Indian in the Cu...,5
...,...,...,...
179669,5403,Homeport,3
179670,5403,Irish Hearts (Irish Hearts #1 & 2),3
179671,5403,"Brazen Virtue (D.C. Detectives, #2)",2
179672,5403,"Dance Upon The Air (Three Sisters Island, #1)",4


In [85]:
# Build the project dataset from the files under data_dir, handing it the
# tokenizer loaded from the bert-tiny checkpoint.
dataset = UserDataset(tokenizer=tokenizer, data_dir=data_dir)

In [93]:
# Peek at the first row of the 'datas' table held by the first dataset item.
first_item = dataset[0]
first_item['datas'].iloc[0]

name            The Restaurant at the End of the Universe (Hit...
book_id                                                    862825
authors                                             Douglas Adams
avg_rating                                                   4.22
publish_year                                                 1981
desc            Just when you thought it was safe to go back t...
Name: 0, dtype: object

In [67]:
# Single-output head (num_classes=1): the classifier is used as a rating regressor.
# Build it for the same device the training loop moves batches to instead of a
# hard-coded 'cpu'; otherwise model and data end up on different devices as soon
# as CUDA is available. str(device) keeps the argument a plain device string
# ('cpu' / 'cuda'), exactly what was passed before on a CPU-only machine.
bert_model = BERTClassifier(num_classes=1, device=str(device))

In [74]:
# Pull one batch (batch_size=1) out of the loader so its structure can be inspected.
first_batch_loader = DataLoader(dataset, batch_size=1)
with tqdm.tqdm(first_batch_loader) as pbar:
    for idx, batch in enumerate(pbar):
        # Stop after the first element; `idx` and `batch` stay bound for later cells.
        break

  0%|          | 0/3850 [00:07<?, ?it/s]


In [76]:
# Show which fields the collated batch dict exposes.
batch.keys()

dict_keys(['book_embed', 'book_attention_mask', 'book_id', 'rating'])

In [79]:
# BERTClassifier.forward takes the whole batch dict as its single argument (it is
# invoked as `bert_model(batch)` in the training cell), so the example inputs for
# tracing must be a one-element tuple holding that dict.
# Passing the two tensors as separate positional inputs failed with:
#   TypeError: BERTClassifier.forward() takes 2 positional arguments but 3 were given
script_model = torch.jit.trace(bert_model, (batch,))

In [8]:
# Regression objective: mean squared error between predicted and target ratings.
loss = nn.MSELoss()
# Optimize the network that is actually run in the loop below. The previous
# version passed `model.parameters()` (the bare AutoModel loaded earlier), so
# optimizer.step() never touched bert_model's weights and nothing was learned.
optimizer = torch.optim.Adam(bert_model.parameters(), lr=1e-5)

with tqdm.tqdm(DataLoader(dataset, batch_size=1)) as pbar:
    for idx, batch in enumerate(pbar):
        # Clear gradients first, so anything left behind by an interrupted
        # previous run of this cell cannot leak into the current step.
        optimizer.zero_grad()

        # Move every tensor of the batch to the selected device.
        batch = {key: value.to(device) for key, value in batch.items()}

        logits = bert_model(batch)
        # NOTE(review): assumes logits and batch['rating'] have matching shapes;
        # MSELoss broadcasts silently otherwise - TODO confirm.
        loss_value = loss(logits, batch['rating'])
        loss_value.backward()
        optimizer.step()

        pbar.set_description(f"loss: {loss_value.item():.4f}")

  0%|          | 0/3850 [00:00<?, ?it/s]

input_ids.shape=torch.Size([1, 829, 512])
output.shape=torch.Size([829, 128])
logits.shape=torch.Size([1, 829])


loss: 1.8225:   0%|          | 1/3850 [00:17<19:02:57, 17.82s/it]

input_ids.shape=torch.Size([1, 158, 512])


loss: 1.8225:   0%|          | 1/3850 [00:19<21:08:10, 19.77s/it]


KeyboardInterrupt: 

In [29]:
# Shape of `logits` after swapping its first two axes.
logits.transpose(0, 1).shape

torch.Size([1, 829])