In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from tqdm import tqdm
from accelerate import Accelerator


In [None]:
accelerator = Accelerator(precision="fp32")


In [None]:
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

train_dataset = datasets.MNIST(
    root="./data",
    train=True,
    download=True,
    transform=transform
)

train_loader = DataLoader(
    train_dataset,
    batch_size=128,
    shuffle=True
)


In [None]:
class SimpleFFN(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28 * 28, 256)
        self.fc2 = nn.Linear(256, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        return self.fc2(x)


In [None]:
model = SimpleFFN()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)


In [None]:
model, optimizer, train_loader = accelerator.prepare(
    model, optimizer, train_loader
)


In [None]:
model.train()
losses = []
for epoch in range(3):
    for images, labels in tqdm(train_loader):
        optimizer.zero_grad()

        outputs = model(images)
        loss = F.cross_entropy(outputs, labels)
        losses.append(loss.item())
        accelerator.backward(loss)
        optimizer.step()

    accelerator.print(f"Epoch {epoch} done")


In [None]:
import pandas as pd


In [None]:
s = pd.Series(losses)
s.plot()


In [1]:
import random
import numpy as np

In [29]:
def generate_row():
    bucket_lengths = [16,32,64,128]
    bucket = random.randint(0,3)
    seq = np.random.random_integers(low=0, high=100, size=(bucket_lengths[bucket],))

    return {
        "bucket": bucket,
        "seq": seq
    }

In [30]:
import pandas as pd

In [35]:
rows = [generate_row() for i in range(16000)]
df = pd.DataFrame(rows[:8000])

  seq = np.random.random_integers(low=0, high=100, size=(bucket_lengths[bucket],))


In [36]:
import pyarrow as pa
import pyarrow.parquet as pq

In [37]:
df.to_parquet("../data/pq_2.parquet")

In [32]:
df['seq'][0].shape

(64,)