In [100]:
import copy

import numpy as np
import pandas as pd
import glob
import os
import tqdm

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader

from transformers import AutoModel, AutoTokenizer
from dtu_proj.data.dataset import UserDataset

%load_ext autoreload
%autoreload 2

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload
cpu


In [101]:
checkpoint = "prajjwal1/bert-tiny" # L=2, H=128
model = AutoModel.from_pretrained(checkpoint)
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [102]:
data_dir = '../data'
raw_dir = '../data/raw'
processed_dir = '../data/processed'

In [103]:
# os.listdir('../data/raw')
dataset = UserDataset(data_dir=data_dir)

In [123]:
dataset.books[dataset.books['Name'] == 'Their Eyes Were Watching God']

Unnamed: 0,Id,Name,Authors,Rating,PublishYear,Description
48548,1649067,Their Eyes Were Watching God,Zora Neale Hurston,3.93,2000,<strong>“A deeply soulful novel that comprehen...
76382,730364,Their Eyes Were Watching God,Zora Neale Hurston,3.92,2005,"At the age of 16, Janie is caught kissing the ..."
325908,885843,Their Eyes Were Watching God,Zora Neale Hurston,3.92,2000,One of the most important and enduring books o...
478200,2435581,Their Eyes Were Watching God,Zora Neale Hurston,3.93,1996,"First published in 1937, Their Eyes Were Watch..."
702336,904282,Their Eyes Were Watching God,Zora Neale Hurston,3.92,1978,"Fair and long-legged, independent and articula..."
758859,1162432,Their Eyes Were Watching God,Zora Neale Hurston,3.92,2008,“A deeply soulful novel that comprehends love ...
779788,3022240,Their Eyes Were Watching God,Zora Neale Hurston,3.93,2003,"A classic of black literature, this is the sto..."


In [131]:
# em = dataset.get_rating(1)
users = dataset[0]
users

     ID                                               Name  Rating
0     1  Agile Web Development with Rails: A Pragmatic ...       5
1     1  The Restaurant at the End of the Universe (Hit...       5
2     1                                         Siddhartha       5
3     1  The Clock of the Long Now: Time and Responsibi...       4
4     1            Ready Player One (Ready Player One, #1)       4
..   ..                                                ...     ...
473   1                                           The Firm       4
474   1      The Man With the Golden Gun (James Bond, #13)       5
475   1                                 The Water Princess       5
476   1                                  Paris to the Moon       4
477   1   Harry Potter Series Box Set (Harry Potter, #1-7)       5

[478 rows x 3 columns]


ValueError: a must be greater than 0 unless no samples are taken

In [73]:
em.shape

torch.Size([512])

In [7]:
class BERTClassifier(nn.Module):
    def __init__(self, num_classes, freeze_bert=False):
        super(BERTClassifier, self).__init__()
        # Instantiating BERT-based model object
        self.bert = AutoModel.from_pretrained(checkpoint)
        self.bert.config.problem_type = 'regression'

        # Defining layers like dropout and linear
        self.dropout = nn.Dropout(0.1)
        self.classifier = nn.Linear(self.bert.config.hidden_size, num_classes)
        self.softmax = nn.Softmax(dim=1)


    def forward(self, input_ids, attention_mask):

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        
        # Feeding the input to BERT-based model to obtain contextualized representations
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)

        # Extracting the representations of [CLS] head
        last_hidden_state_cls = outputs.pooler_output  # shape = (B, 128)

        x = self.dropout(last_hidden_state_cls)
        
        # Feeding cls_rep to the classifier layer
        logits = self.classifier(x)

        return logits
    
model = BERTClassifier(num_classes=1).to(device)  # regression

In [41]:
loss = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

with tqdm.tqdm(DataLoader(dataset, batch_size=64)) as pbar:
    for idx, batch in enumerate(pbar):
        batch = {k: v.to(device) for k, v in batch.items()}
        logits = model(batch['input_ids'], batch['attention_mask'])
        loss_value = loss(logits, batch['target'].unsqueeze(1))
        loss_value.backward()
        optimizer.step()
        optimizer.zero_grad()
        pbar.set_description(f"loss: {loss_value.item():.4f}")

loss: 2.3992:  15%|███████████████████▏                                                                                                             | 2719/18300 [04:55<28:28,  9.12it/s]

loss: 1.0274:  17%|█████████████████████▎                                                                                                           | 3027/18300 [05:30<27:45,  9.17it/s]


KeyboardInterrupt: 