In [1]:
import os
import sys

In [2]:
import torch

In [3]:
from datasets import load_dataset

In [4]:
dataset = load_dataset("allenai/real-toxicity-prompts")

In [5]:
dataset = dataset['train']

In [6]:
import numpy as np
def get_random_tensors(l, hdim=100):
    """returns l x hdim tensor"""
    return torch.rand(l, hdim)
def get_y(datum):
    """returns rtp signal: bsize x 8"""
    cont = [d['continuation'] for d in datum]
    return torch.tensor(np.nan_to_num(np.array([list(l.values())[1:] for l in cont], dtype=np.float64)), dtype=torch.float64)

In [7]:
# lets batch encode a dataset:
class randomDataset(torch.utils.data.dataset.Dataset):
    def __init__(self, dataset, hdim=100):
        self.X = get_random_tensors(len(dataset), hdim)
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx]

### Putting it all together

In [8]:
from tqdm import tqdm

In [9]:
import wandb

In [10]:
# lets put it all together:
wandb.init(project="deepGenTest")
hdim, odim, bsize, epochs, grad_accm_steps = 100, 8, 64, 10, 1
model = torch.nn.Sequential(
    torch.nn.Linear(hdim, 150),
    torch.nn.ReLU(),
    torch.nn.Linear(150, 200),
    torch.nn.ReLU(),
    torch.nn.Linear(200, 300),
    torch.nn.ReLU(),
    torch.nn.Linear(300, 200),
    torch.nn.ReLU(),
    torch.nn.Linear(200, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, odim)
)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1, weight_decay=0)
loss_fn = torch.nn.CrossEntropyLoss()
Xdata = randomDataset(dataset)
ydl = torch.utils.data.dataloader.DataLoader(dataset, batch_size=bsize, shuffle=False, collate_fn=get_y)
Xdl = torch.utils.data.dataloader.DataLoader(Xdata, bsize, shuffle=False)

for e in range(epochs):
    epoch_loss = 0
    for ybatch, xbatch in tqdm(zip(ydl, Xdl), total = len(ydl), leave=True, desc=f"Epoch: {e}"):
        optimizer.zero_grad()
        fwd = model(xbatch)
        loss = loss_fn(fwd, ybatch)
        loss.backward()
        optimizer.step()
        wandb.log({"batch_loss" : loss.item()})
        epoch_loss += loss.item()
    wandb.log({"epoch_loss" : epoch_loss})
    print(f"with loss: {epoch_loss}")
    

[34m[1mwandb[0m: Currently logged in as: [33mzubin[0m ([33mlowercaselabs[0m). Use [1m`wandb login --relogin`[0m to force relogin


Epoch: 0: 100%|████████████████████████████| 1554/1554 [00:04<00:00, 315.70it/s]


with loss: 45460.008276687986


Epoch: 1: 100%|████████████████████████████| 1554/1554 [00:05<00:00, 308.65it/s]


with loss: 7155.70156301207


Epoch: 2: 100%|████████████████████████████| 1554/1554 [00:05<00:00, 305.58it/s]


with loss: 7155.756032602753


Epoch: 3: 100%|████████████████████████████| 1554/1554 [00:05<00:00, 309.98it/s]


with loss: 7155.766072790327


Epoch: 4: 100%|████████████████████████████| 1554/1554 [00:05<00:00, 307.89it/s]


with loss: 7155.767956759366


Epoch: 5: 100%|████████████████████████████| 1554/1554 [00:05<00:00, 304.68it/s]


with loss: 7155.768343551172


Epoch: 6: 100%|████████████████████████████| 1554/1554 [00:05<00:00, 307.42it/s]


with loss: 7155.768395236311


Epoch: 7: 100%|████████████████████████████| 1554/1554 [00:05<00:00, 308.18it/s]


with loss: 7155.768410863666


Epoch: 8: 100%|████████████████████████████| 1554/1554 [00:05<00:00, 309.39it/s]


with loss: 7155.768428974206


Epoch: 9: 100%|████████████████████████████| 1554/1554 [00:05<00:00, 308.10it/s]

with loss: 7155.768415547197





# Lets do this for civil_chat

In [11]:
civil = load_dataset("google/civil_comments")

Downloading readme:   0%|          | 0.00/7.73k [00:00<?, ?B/s]

Downloading data files:   0%|          | 0/3 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/194M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/187M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/21.0M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/20.8M [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/3 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/1804874 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/97320 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/97320 [00:00<?, ? examples/s]

In [12]:
len(civil)

3

In [13]:
next(iter(civil))

'train'

In [15]:
for k in civil.keys():
    print(len(civil[k]))

1804874
97320
97320


In [16]:
next(iter(civil['train']))

{'text': "This is so cool. It's like, 'would you want your mother to read this??' Really great idea, well done!",
 'toxicity': 0.0,
 'severe_toxicity': 0.0,
 'obscene': 0.0,
 'threat': 0.0,
 'insult': 0.0,
 'identity_attack': 0.0,
 'sexual_explicit': 0.0}

In [17]:
def civil_collate(datum):
    return torch.tensor(np.nan_to_num(np.array([list(l.values())[1:] for l in datum], dtype=np.float64)), dtype=torch.float64)

In [19]:
civil_collate([next(iter(civil['train']))])

tensor([[0., 0., 0., 0., 0., 0., 0.]], dtype=torch.float64)

In [21]:
# lets put it all together:
wandb.init(project="deepGenTest")
hdim, odim, bsize, epochs, grad_accm_steps = 100, 7, 64, 1, 1
model = torch.nn.Sequential(
    torch.nn.Linear(hdim, 150),
    torch.nn.ReLU(),
    torch.nn.Linear(150, 200),
    # torch.nn.ReLU(),
    # torch.nn.Linear(200, 300),
    # torch.nn.ReLU(),
    # torch.nn.Linear(300, 200),
    torch.nn.ReLU(),
    torch.nn.Linear(200, 100),
    torch.nn.ReLU(),
    torch.nn.Linear(100, odim)
)
optimizer = torch.optim.AdamW(model.parameters(), lr=0.1, weight_decay=0)
loss_fn = torch.nn.CrossEntropyLoss()
Xdata = randomDataset(civil['train'])
ydl = torch.utils.data.dataloader.DataLoader(civil['train'], batch_size=bsize, shuffle=False, collate_fn=civil_collate)
Xdl = torch.utils.data.dataloader.DataLoader(Xdata, bsize, shuffle=False)

for e in range(epochs):
    epoch_loss = 0
    for ybatch, xbatch in tqdm(zip(ydl, Xdl), total = len(ydl), leave=True, desc=f"Epoch: {e}"):
        optimizer.zero_grad()
        fwd = model(xbatch)
        loss = loss_fn(fwd, ybatch)
        loss.backward()
        optimizer.step()
        wandb.log({"batch_loss" : loss.item()})
        epoch_loss += loss.item()
    wandb.log({"epoch_loss" : epoch_loss})
    print(f"with loss: {epoch_loss}")
    

VBox(children=(Label(value='0.001 MB of 0.001 MB uploaded\r'), FloatProgress(value=1.0, max=1.0)))

0,1
batch_loss,▁▃▃▄▃▂▃▂▁▃▂▂▃▃▄▂▂▁▃▃▄▄▂▂▂▁▅▃▃▂▄▄▅▃█▄▂▃▃▄

0,1
batch_loss,0.15656


VBox(children=(Label(value='Waiting for wandb.init()...\r'), FloatProgress(value=0.011339938888947168, max=1.0…

Epoch: 0: 100%|██████████████████████████| 28202/28202 [00:46<00:00, 600.35it/s]

with loss: 9674.690395228059





In [27]:
y = next(iter(ydl))
model(next(iter(Xdl))).shape

torch.Size([64, 7])

In [28]:
y.shape

torch.Size([64, 7])