# How do we train the ESM model?

In [26]:
# import packages
import sys, os, glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, Subset

In [3]:
# build a synthetic regression problem: y = 2x + 1 plus Gaussian noise
np.random.seed(0)  # pin the global NumPy RNG so the toy data is repeatable
n = 1000  # number of samples
x = np.random.random_sample((n, 1))  # uniform draws from [0, 1), one feature per row
# targets follow the line 2x + 1, perturbed by zero-mean noise with std 0.1
y = (2 * x + 1) + 0.1 * np.random.standard_normal((n, 1))

In [14]:
# convert both NumPy arrays to float32 torch tensors (torch.tensor copies the data)
x_tensor, y_tensor = (
    torch.tensor(array, dtype=torch.float32) for array in (x, y)
)

In [15]:
# create a simple dataset class
class SimpleDataset(Dataset):
    """Map-style dataset that pairs each input row with its target row.

    Args:
        x_data: Sized, indexable container of inputs (e.g. a tensor).
        y_data: Sized, indexable container of targets; must have the same
            length as ``x_data``.

    Raises:
        ValueError: If ``x_data`` and ``y_data`` have different lengths.
    """

    def __init__(self, x_data, y_data):
        # Fail fast: a length mismatch would otherwise only surface as an
        # IndexError part-way through an epoch (or silently drop targets).
        if len(x_data) != len(y_data):
            raise ValueError(
                f"x_data and y_data must have the same length, "
                f"got {len(x_data)} and {len(y_data)}"
            )
        self.x_data = x_data
        self.y_data = y_data

    def __len__(self):
        """Return the number of (input, target) pairs."""
        return len(self.x_data)

    def __getitem__(self, idx):
        """Return the ``(input, target)`` pair stored at position ``idx``."""
        return self.x_data[idx], self.y_data[idx]

# wrap the input / target tensors in the dataset defined above
dataset = SimpleDataset(x_data=x_tensor, y_data=y_tensor)

In [36]:
# split the dataset into train / test subsets and wrap each one in a DataLoader

SPLIT_SEED = 0   # seeds both the index split and the training shuffle
BATCH_SIZE = 32  # samples per mini-batch for both loaders

train_prop = 0.8
train_size = int(train_prop * len(dataset))

# Use a dedicated, locally seeded generator instead of the global np.random
# state: the split no longer depends on how many random draws earlier cells
# made, so re-running this cell (or running cells out of order) gives the
# same train / test partition every time.
split_rng = np.random.default_rng(SPLIT_SEED)
train_idx = split_rng.choice(len(dataset), size=train_size, replace=False)
test_idx = np.setdiff1d(np.arange(len(dataset)), train_idx)  # everything not in train, sorted

# sanity check: the two index sets partition the dataset
assert len(train_idx) + len(test_idx) == len(dataset)

train_dataset = Subset(dataset, train_idx)
test_dataset = Subset(dataset, test_idx)

# Seed the shuffle explicitly so the batch order is repeatable after a fresh
# run of this cell (the generator advances as the loader is iterated).
shuffle_generator = torch.Generator().manual_seed(SPLIT_SEED)
train_loader = DataLoader(
    train_dataset, batch_size=BATCH_SIZE, shuffle=True, generator=shuffle_generator
)
test_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [42]:
# sanity check: pull a single batch from the training loader and show its shapes
first_batch = next(iter(train_loader), None)  # None when the loader yields nothing
if first_batch is not None:
    x_batch, y_batch = first_batch
    print(x_batch.shape, y_batch.shape)

torch.Size([32, 1]) torch.Size([32, 1])


In [44]:
# define the model
class Net(nn.Module):
    """Small fully connected network: Linear -> ReLU -> Linear.

    Args:
        in_features: Size of each input sample (default 1).
        hidden_features: Width of the hidden layer (default 64).
        out_features: Size of each output sample (default 1).

    The defaults reproduce the original fixed 1 -> 64 -> 1 architecture, so
    ``Net()`` builds the same layers under the same attribute names.
    """

    def __init__(self, in_features=1, hidden_features=64, out_features=1):
        super().__init__()
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.act = nn.ReLU()

    def forward(self, x):
        """Apply fc1, ReLU, then fc2 to ``x`` and return the result.

        ``x`` must have ``in_features`` as its last dimension; the output has
        ``out_features`` as its last dimension.
        """
        x = self.fc1(x)  # in_features -> hidden_features
        x = self.act(x)  # element-wise non-linearity, shape unchanged
        x = self.fc2(x)  # hidden_features -> out_features
        return x

model = Net()

In [52]:
# training setup: AdamW updates every model parameter, scored with mean-squared error
optimizer = optim.AdamW(params=model.parameters(), lr=1e-3)
loss_fn = nn.MSELoss(reduction="mean")  # "mean" is the default, spelled out for clarity

In [None]:
# train the model