In [1]:
import pickle
import pandas
import numpy as np


from transformers import T5Model, T5Tokenizer, AdamW
import pandas as pd
from torch.utils.data import Dataset , DataLoader
import pytorch_lightning as pl
from sklearn.model_selection import train_test_split

In [2]:
romantic_fileList = []
with open("../data/humor_train.p", mode="rb") as f:
    romantic_fileList = pickle.load(f)
humor_fileList = []
with open("../data/humor_train.p", mode="rb") as f:
    humor_fileList = pickle.load(f)
romantic_captions = []

In [3]:
original_df = pandas.read_csv("../data/Flickr8k.token.txt", sep="\t", header=None, names=["id", "sentence"])
original_df = original_df.assign(id=original_df.id.str.split("#").str[0])
# temporarily only use the first caption
original_df = original_df.groupby("id").agg(lambda x: x.iloc[0])
original_df

Unnamed: 0_level_0,sentence
id,Unnamed: 1_level_1
1000268201_693b08cb0e.jpg,A child in a pink dress is climbing up a set o...
1001773457_577c3a7d70.jpg,A black dog and a spotted dog are fighting
1002674143_1b742ab4b8.jpg,A little girl covered in paint sits in front o...
1003163366_44323f5815.jpg,A man lays on a bench while his dog sits by him .
1007129816_e794419615.jpg,A man in an orange hat starring at something .
...,...
990890291_afc72be141.jpg,A man does a wheelie on his bicycle on the sid...
99171998_7cc800ceef.jpg,A group is sitting around a snowy crevasse .
99679241_adc853a5c0.jpg,A grey bird stands majestically on a beach whi...
997338199_7343367d7f.jpg,A person stands near golden walls .


In [12]:
romantic_df = pandas.read_csv("../data/romantic_train.txt", names=["sentence"]).assign(id=romantic_fileList)
humor_df = pandas.read_csv("../data/funny_train.txt", names=["sentence"]).assign(id=humor_fileList)
original_df = original_df.merge(romantic_df, on="id", how="right", suffixes=["_original", "_romantic"]).merge(humor_df, on="id", how="right")
df = original_df.rename(columns={"sentence": "sentence_humor"})

In [41]:
from transformers import T5ForConditionalGeneration

t5 = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

In [42]:
class FlickrDataset(Dataset):
    def __init__(self,original,romantic,humor):
        self.original = original
        self.humor = humor
        self.romantic = romantic

    def __len__(self):
        return len(self.original)

    def to_token(self,sentence):
        return tokenizer.encode(sentence, max_length=30, truncation=True, padding="max_length", return_tensors="pt")

    def __getitem__(self,index):
        original_token = self.to_token(self.original.iloc[index])
        humor_token = self.to_token(self.humor.iloc[index])
        romantic_token = self.to_token(self.romantic.iloc[index])
        return original_token, humor_token, romantic_token

In [43]:
dataset = FlickrDataset(original_df.sentence_original,original_df.sentence_humor,original_df.sentence_romantic)

In [65]:
train_dataset, val_dataset = train_test_split(df, test_size=0.2)
batch_size = 32

In [66]:
class FlickrDataModule(pl.LightningDataModule):
    def __init__(self):
        super().__init__()
        self.train = FlickrDataset(train_dataset.sentence_original,train_dataset.sentence_humor,train_dataset.sentence_romantic)
        self.test = FlickrDataset(val_dataset.sentence_original,val_dataset.sentence_humor,val_dataset.sentence_romantic)
        self.val = FlickrDataset(val_dataset.sentence_original,val_dataset.sentence_humor,val_dataset.sentence_romantic)

    def train_dataloader(self):
        return DataLoader(self.train , batch_size = batch_size , shuffle = True)
    def test_dataloader(self):
        return DataLoader(self.test , batch_size = batch_size , shuffle = False)
    def val_dataloader(self):
        return DataLoader(self.val , batch_size = batch_size , shuffle = False)

In [67]:
next(iter(FlickrDataModule().train_dataloader()))

[tensor([[[   71,  4216,    11,   872,  2586,     3,  8623,    57,  3124,     3,
               5,     1,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0]],
 
         [[ 5245,   803,   492,  8519,    21,     8,  1861,     3,     5,     1,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0]],
 
         [[   71,  4595,   221,    26,   388,    19,  4125,    57,     3,     9,
            2309,  6998,  3609,    95,     3,     9, 21451,  1320,     3,     5,
               1,     0,     0,     0,     0,     0,     0,     0,     0,     0]],
 
         [[   71,  4940, 20584,     7,   112,   819,   190,     3,     9,  7445,
            1554,     3,     5,     1,     0,     0,     0,     0,     0,     0,
               0,     0,     0,     0,     0,     0,     0,     0,     0,     0]],
 
         [[ 

In [86]:
class StyleTransferModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        t5.train()
        self.neural_net = t5

    def forward(self, original, transferred):
        original = original.reshape(batch_size, -1)
        transferred = transferred.reshape(batch_size, -1)
        return self.neural_net(input_ids=original, labels=transferred)

    def configure_optimizers(self):
        return AdamW(self.parameters(), 1e-4)

    def training_step(self,batch,batch_idx):
        original, transferred = batch[0], batch[1]
        output = self(original, transferred)
        return output.loss

    def test_step(self,batch,batch_idx):
        original, transferred = batch[0], batch[1]
        output = self(original, transferred)
        return output.loss

    def validation_step(self,batch,batch_idx):
        original, transferred = batch[0], batch[1]
        output = self(original, transferred)
        return output.loss

In [87]:
from pytorch_lightning import Trainer
model = StyleTransferModel()
module = FlickrDataModule()
trainer = Trainer(max_epochs = 8)
trainer.fit(model,module)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs

  | Name       | Type                       | Params
----------------------------------------------------------
0 | neural_net | T5ForConditionalGeneration | 60.5 M
----------------------------------------------------------
60.5 M    Trainable params
0         Non-trainable params
60.5 M    Total params
242.026   Total estimated model params size (MB)


Validation sanity check: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validating: 0it [00:00, ?it/s]

RuntimeError: shape '[32, -1]' is invalid for input of size 720