In [2]:
import accelerate
import torch
from torch import nn
from torch.utils.data import TensorDataset, DataLoader
from torch.optim import AdamW

In [5]:
def train(n=10000, num_samples=8, dim=1024, lr=1e-3):
    """Run one pass of a toy pairwise-logit training loop under Accelerate.

    Builds ``num_samples`` random samples, each holding ``n`` embeddings of
    width ``dim`` together with an ``(n, n)`` matrix of sigmoid-squashed
    random targets.  A single linear layer is trained so that
    ``model(x) @ x.T`` matches those targets under ``BCEWithLogitsLoss``.
    The batch shapes and the loss are printed for every step.

    Args:
        n: Number of embeddings per sample (the label matrix is ``n x n``).
        num_samples: Number of samples; one optimizer step is taken per sample.
        dim: Embedding width, also the in/out size of the linear layer.
        lr: Learning rate passed to AdamW.

    Returns:
        None.
    """
    accelerator = accelerate.Accelerator()

    embeddings = torch.randn((num_samples, n, dim))
    # In-place sigmoid avoids holding a second (num_samples, n, n) tensor
    # alive at the same time; at the default sizes each one is about 3.2 GB
    # with the default float32 dtype.
    labels = torch.randn((num_samples, n, n)).sigmoid_()

    dataset = TensorDataset(embeddings, labels)
    dataloader = DataLoader(dataset, batch_size=1)
    model = nn.Linear(dim, dim)
    optimizer = AdamW(model.parameters(), lr=lr)
    model, dataloader, optimizer = accelerator.prepare(
        model, dataloader, optimizer)

    bce = nn.BCEWithLogitsLoss()
    model.train()
    for step, (batch_embeddings, batch_labels) in enumerate(dataloader, start=1):
        # batch_size=1 adds a leading axis of size 1.  Drop only that axis:
        # a bare squeeze() would also collapse a singleton n or dim and make
        # the matrix product below fail.
        batch_embeddings = batch_embeddings.squeeze(0)
        batch_labels = batch_labels.squeeze(0)
        print(f"batch {step:2d}: "
              f"embeddings: {batch_embeddings.shape}, "
              f"labels: {batch_labels.shape}")

        # (n, dim) @ (dim, n) -> (n, n) logits, one per pair of embeddings.
        out = model(batch_embeddings).mm(batch_embeddings.T)
        loss = bce(out, batch_labels)
        print(f"batch {step:2d}: loss = {loss.item()}")

        accelerator.backward(loss)
        optimizer.step()
        optimizer.zero_grad()

In [6]:
# Run `train` through Accelerate's notebook launcher, requesting a single process.
accelerate.notebook_launcher(train, num_processes=1)

Launching training on one GPU.
batch  1: embeddings: torch.Size([10000, 1024]), labels: torch.Size([10000, 10000])
batch  1: loss = 7.4055681228637695
batch  2: embeddings: torch.Size([10000, 1024]), labels: torch.Size([10000, 10000])
batch  2: loss = 7.086696147918701
batch  3: embeddings: torch.Size([10000, 1024]), labels: torch.Size([10000, 10000])
batch  3: loss = 6.776832103729248
batch  4: embeddings: torch.Size([10000, 1024]), labels: torch.Size([10000, 10000])
batch  4: loss = 6.458182334899902
batch  5: embeddings: torch.Size([10000, 1024]), labels: torch.Size([10000, 10000])
batch  5: loss = 6.1594557762146
batch  6: embeddings: torch.Size([10000, 1024]), labels: torch.Size([10000, 10000])
batch  6: loss = 5.851268291473389
batch  7: embeddings: torch.Size([10000, 1024]), labels: torch.Size([10000, 10000])
batch  7: loss = 5.554714679718018
batch  8: embeddings: torch.Size([10000, 1024]), labels: torch.Size([10000, 10000])
batch  8: loss = 5.257957935333252


In [7]:
# Small 5x2 scratch tensor of standard-normal values for poking at indexing.
a = torch.randn(5, 2)

In [9]:
# Indexing one column with an integer drops that axis, leaving a 1-D result.
a[:, 0].shape

torch.Size([5])