In [1]:
import os
outdir = './p1'

In [3]:
if not os.path.exists(outdir):
    os.mkdir(outdir)

In [4]:
import urllib
urllib.request.urlretrieve('http://140.114.76.113:8000/pA1.csv', os.path.join(outdir, 'pA1.csv'))
urllib.request.urlretrieve('http://140.114.76.113:8000/pA2.csv', os.path.join(outdir, 'pA2.csv'))

('./p1\\pA2.csv', <http.client.HTTPMessage at 0x1ef47282208>)

In [7]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader

In [8]:
seed = 999
torch.manual_seed(seed)
torch.backends.cudnn.deterministic = True

In [14]:
class Data:
    def __init__(self, csv_path):
        super().__init__()
        self.anns = pd.read_csv(csv_path).to_dict('records')

    def __len__(self):
        return len(self.anns)

    def __getitem__(self, idx):
        ann = self.anns[idx]
        x = torch.tensor(ann['x'])
        y = torch.tensor(ann['y'])
        return x, y


class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.a = nn.Parameter(torch.rand(1) * 0.001)
        self.b = nn.Parameter(torch.rand(1) * 0.001)
    
    def forward(self, xs):
        ps = self.a * xs + self.b
        return ps


data = Data(os.path.join(outdir, 'pA1.csv'))
print(data.anns)
loader = DataLoader(data, batch_size=5)

device = 'cuda'
model = Net().to(device)
criterion = nn.L1Loss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-1)

history = {
    'loss': [],
    'a': [],
    'b': []
}

for epoch in range(50):
    for xs, ys in iter(loader):
        xs = xs.to(device)
        ys = ys.to(device)

        optimizer.zero_grad()
        ps = model(xs)
        loss = criterion(ps, ys)
        loss.backward()
        optimizer.step()

        history['loss'].append(loss.detach().item())
        history['a'].append(model.a.item())
        history['b'].append(model.b.item())

print(model.a)
print(model.b)

[{'x': 0.375, 'y': 6.107}, {'x': 0.951, 'y': 8.891}, {'x': 0.732, 'y': 7.88}, {'x': 0.599, 'y': 7.191}, {'x': 0.156, 'y': 4.829}, {'x': 0.156, 'y': 4.9910000000000005}, {'x': 0.057999999999999996, 'y': 4.085}, {'x': 0.866, 'y': 8.179}, {'x': 0.601, 'y': 6.778}, {'x': 0.708, 'y': 7.452999999999999}, {'x': 0.021, 'y': 4.047}, {'x': 0.97, 'y': 8.735}, {'x': 0.8320000000000001, 'y': 8.327}, {'x': 0.212, 'y': 4.99}, {'x': 0.182, 'y': 4.8}, {'x': 0.183, 'y': 4.938}, {'x': 0.304, 'y': 5.3420000000000005}, {'x': 0.525, 'y': 6.775}, {'x': 0.43200000000000005, 'y': 5.947}, {'x': 0.29100000000000004, 'y': 5.7}, {'x': 0.612, 'y': 7.195}, {'x': 0.139, 'y': 4.547}, {'x': 0.292, 'y': 5.212999999999999}, {'x': 0.366, 'y': 5.99}, {'x': 0.456, 'y': 6.3839999999999995}, {'x': 0.785, 'y': 8.04}, {'x': 0.2, 'y': 5.1339999999999995}, {'x': 0.514, 'y': 6.358}, {'x': 0.5920000000000001, 'y': 6.891}, {'x': 0.046, 'y': 4.04}, {'x': 0.608, 'y': 7.218999999999999}, {'x': 0.171, 'y': 4.914}, {'x': 0.065, 'y': 4.24