In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np

In [2]:
class SAE(nn.Module):
    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.encoder = nn.Linear(in_features=input_size, out_features=hidden_size, bias=True)
        self.thresh = nn.Parameter(torch.zeros(hidden_size), requires_grad=True)
        self.decoder = nn.Linear(in_features=hidden_size, out_features=input_size, bias=True)

    def encode(self, x):
        y = self.encoder(x)
        mask = (y > self.thresh)
        y = mask * nn.functional.relu(y)
        return y

    def decode(self, x):
        y = self.decoder(x)
        return y

    def forward(self, x):
        y = self.encode(x)
        y = self.decode(y)
        return y

In [3]:
dataset = np.load("/data/mech/data/output/official_it_04_clipped_1_99_percentile_scaled_0_to_1_normalized.npy", mmap_mode="r")
train_loader = torch.utils.data.DataLoader(dataset, batch_size=256, shuffle=False)

In [4]:
model = SAE(input_size=768, hidden_size=8888).cuda()

In [5]:
state_dict = torch.load("/home/john/Downloads/experiments/models/exp001.pth")
model.load_state_dict(state_dict)

  state_dict = torch.load("/home/john/Downloads/experiments/models/exp001.pth")


<All keys matched successfully>

In [6]:
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)

In [6]:
for epoch in range(100):
    count = 0
    for data in train_loader:
        optimizer.zero_grad()
        data = data.cuda()
        outputs = model(data)
        loss = criterion(outputs, data)
        loss.backward()
        optimizer.step()
        if count % 1000 == 0:
            print(f"{epoch} [{count}]: {loss.cpu().item()}")
        count += 1

  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


0 [0]: 1.0899147987365723
0 [1000]: 0.061311617493629456
0 [2000]: 0.031246613711118698
0 [3000]: 0.03321358561515808
0 [4000]: 0.01398485153913498
0 [5000]: 0.01865777187049389
0 [6000]: 0.014735085889697075
1 [0]: 0.015443718060851097
1 [1000]: 0.012435225769877434
1 [2000]: 0.012501420453190804
1 [3000]: 0.018413931131362915
1 [4000]: 0.007246832828968763
1 [5000]: 0.010215869173407555
1 [6000]: 0.00809275358915329
2 [0]: 0.009059442207217216
2 [1000]: 0.007488461211323738
2 [2000]: 0.007562537677586079
2 [3000]: 0.012025205418467522
2 [4000]: 0.005612907465547323
2 [5000]: 0.007165712770074606
2 [6000]: 0.005550870671868324
3 [0]: 0.007082977797836065
3 [1000]: 0.005251803435385227
3 [2000]: 0.005681280046701431
3 [3000]: 0.008611153811216354
3 [4000]: 0.004479141440242529
3 [5000]: 0.0058991629630327225
3 [6000]: 0.004751142580062151
4 [0]: 0.006344749592244625
4 [1000]: 0.004373305477201939
4 [2000]: 0.004868806805461645
4 [3000]: 0.0070645930245518684
4 [4000]: 0.004008993040770

In [8]:
data = next(iter(train_loader))
inputs_temp = data[:10].cuda()
with torch.no_grad():
    outputs_temp = model(inputs_temp.cuda())

  return collate([torch.as_tensor(b) for b in batch], collate_fn_map=collate_fn_map)


In [9]:
inputs_temp.shape, outputs_temp.shape

(torch.Size([10, 768]), torch.Size([10, 768]))

In [10]:
1 - torch.mean((inputs_temp - outputs_temp)**2) / torch.var(inputs_temp)

tensor(0.9963, device='cuda:0')

In [10]:
torch.save(model.state_dict(), "/home/john/Downloads/experiments/models/exp001.pth")

In [18]:
means_new = np.load("/data/mech/data/output/official_it_04_clipped_1_99_percentile_scaled_0_to_1_mean.npy")
means_old = np.load("/data/mech/data/train/official_it_04_clipped_1_99_percentile_scaled_0_to_1_mean.npy")
print(np.abs((means_new - means_old)).sum() / max(means_new.sum(), means_old.sum()))
print(means_new.sum(), means_old.sum())

0.012523366
384.06372 383.6629


In [19]:
std_new = np.load("/data/mech/data/output/official_it_04_clipped_1_99_percentile_scaled_0_to_1_std.npy")
std_old = np.load("/data/mech/data/train/official_it_04_clipped_1_99_percentile_scaled_0_to_1_std.npy")
print(np.abs(std_new - std_old).sum() / max(std_new.sum(), std_old.sum()))
print(std_new.sum(), std_old.sum())

0.0073449966
158.74097 158.87686
