In [3]:
import os
import shutil
import yaml
import time, timeit
from collections import namedtuple
import numpy as np
import torch
from thinker.actor_net import ActorNet
from thinker.main import Env
import thinker.util as util
from thinker.self_play import init_env_out, create_env_out

class DetectBuffer:
    def __init__(self, outdir, t, rec_t, logger, delay_n=5):
        """
        Store training data grouped in planning stages and write it to disk
        once the (delayed) target output for those stages is also ready.
            Args:
                outdir (str): directory the ``data_{idx}.pt`` files are saved to
                t (int): number of planning stages per output file
                rec_t (int): number of steps in a planning stage
                logger: object with an ``info`` method, used to report saved files
                delay_n (int): number of later planning stages whose y is folded
                    into the target of each stage (see ``_compute_target_y``);
                    0 uses each stage's own y as its target
        """
        self.outdir = outdir
        self.t = t  # number of planning stages per file
        self.rec_t = rec_t
        self.logger = logger
        self.delay_n = delay_n

        # processed_n counts buffered steps whose status satisfies _is_real_step
        self.processed_n, self.xs, self.y, self.done, self.step_status = 0, [], [], [], []
        self.file_idx = -1  # index of the last saved file; -1 before the first one

    @staticmethod
    def _is_real_step(step_status):
        """Return whether step_status (int or long tensor) is 0 or 3, i.e. the last step was real."""
        return (step_status == 0) | (step_status == 3)

    def insert(self, xs, y, done, step_status):
        """
        Buffer one step and save a file once enough planning stages are complete.

        Args:
            xs (dict): dictionary of training input, with each elem having the
                shape of (B, *)
            y (tensor): bool tensor of shape (B), being the target output delayed by
                delay_n planning stage
            done (tensor): bool tensor of shape (B), being the indicator of episode end
            step_status (int): int indicating current step status
        Output:
            save train_xs in shape (t, rec_t, B, *) and train_y in shape (t, B)
        Returns:
            index of the last saved file (-1 if none has been saved yet)
        """
        last_step_real = self._is_real_step(step_status)
        if len(self.step_status) == 0 and not last_step_real:
            return self.file_idx  # skip until the first real step

        self.xs.append(util.dict_map(xs, lambda x: x.cpu()))
        self.y.append(y.cpu())
        self.done.append(done.cpu())
        self.step_status.append(step_status)
        self.processed_n += int(last_step_real)

        # t stages for the file, delay_n more for the target y, and the start of
        # one further stage so that the last of those stages is known to be complete
        if self.processed_n >= self.t + self.delay_n + 1:
            self.file_idx += 1
            out = self._extract_data(self.t)
            # recount with the same predicate used for the increment above
            # (counting only status 0 missed status-3 steps and broke the assert)
            self.processed_n = sum(int(self._is_real_step(s)) for s in self.step_status)
            assert self.processed_n == self.delay_n+1, f"should only have {self.delay_n + 1} data left instead of {self.processed_n}"
            path = f'{self.outdir}/data_{self.file_idx}.pt'
            torch.save(out, path)
            out_shape = out[0]['env_state'].shape
            n = self.file_idx * out_shape[0] * out_shape[2]
            self.logger.info(f"{n}: File saved to {path}; env_state shape {out_shape}")

        return self.file_idx

    def _extract_data(self, t):
        """Pop the first t planning stages and pair them with their target y.

        Returns:
            (xs, target_y): every input key of xs is viewed as (t, rec_t, *); xs also
            gets "done" (t, B) and "step_status" (t, rec_t); target_y is (t, B).
        """
        xs, y, done, step_status = self._collect_data(t)
        if self.delay_n > 0:
            # peek (without consuming) at the next delay_n stages for the target y
            future_y, future_done = self._collect_data(self.delay_n, y_done_only=True)
            y = torch.concat([y, future_y], dim=0)
            done = torch.concat([done, future_done], dim=0)

        last_step_real = self._is_real_step(step_status)
        assert last_step_real[0], "cur_t should start with 0"
        assert last_step_real.shape[0] == t*self.rec_t, \
            f" step_status.shape is {last_step_real.shape}, expected {t*self.rec_t} for the first dimension."
        assert y.shape[0] == (t + self.delay_n)*self.rec_t, \
            f" y.shape is {y.shape}, expected {(t + self.delay_n)*self.rec_t} for the first dimension."

        B = y.shape[1]
        # keep y / done of the first step of every stage
        y = y.view(t + self.delay_n, self.rec_t, B)[:, 0]
        done = done.view(t + self.delay_n, self.rec_t, B)[:, 0]
        step_status = step_status.view(t, self.rec_t)
        target_y = self._compute_target_y(y, done, self.delay_n)

        for k in xs.keys():
            xs[k] = xs[k].view((t, self.rec_t) + xs[k].shape[1:])

        xs["done"] = done[:t]
        xs["step_status"] = step_status

        return xs, target_y

    def _collect_data(self, t, y_done_only=False):
        """Gather the first t planning stages currently buffered (t >= 1).

        With y_done_only=False the stages are removed from the buffer and
        (xs, y, done, step_status) is returned; with y_done_only=True the buffer
        is left untouched and only (y, done) is returned. A stage ends on a step
        whose status is 2 or 3.
        """
        step_status = torch.tensor(self.step_status, dtype=torch.long)
        next_step_real = (step_status == 2) | (step_status == 3)
        # flatten (not squeeze) so that a single match still gives a 1-D index tensor
        idx = torch.nonzero(next_step_real, as_tuple=False).flatten()
        last_idx = int(idx[t - 1]) + 1
        y = torch.stack(self.y[:last_idx], dim=0)
        done = torch.stack(self.done[:last_idx], dim=0)
        if y_done_only:
            return y, done

        xs = {
            k: torch.stack([v[k] for v in self.xs[:last_idx]], dim=0)
            for k in self.xs[0].keys()
        }
        step_status = step_status[:last_idx]
        self.xs = self.xs[last_idx:]
        self.y = self.y[last_idx:]
        self.done = self.done[last_idx:]
        self.step_status = self.step_status[last_idx:]
        return xs, y, done, step_status

    def _compute_target_y(self, y, done, delay_n):
        """Fold the next delay_n stages' y into each stage's target, stopping at episode ends.

        target_y[i] = y[i] | (~done[i+1] & y[i+1]) | (~done[i+1] & ~done[i+2] & y[i+2]) | ...
                      | (~done[i+1] & ... & ~done[i+delay_n] & y[i+delay_n])

        Args:
            y, done: bool tensors of shape (t + delay_n, B)
            delay_n: look-ahead in stages; 0 returns a copy of y
        Returns:
            bool tensor of shape (t, B)
        """
        t = y.shape[0] - delay_n
        target_y = y[:t].clone()
        not_done = torch.ones_like(target_y, dtype=torch.bool)
        for m in range(1, delay_n + 1):
            # episode must not have ended on any of stages i+1..i+m
            not_done = not_done & ~done[m:m + t]
            target_y = target_y | (not_done & y[m:m + t])
        return target_y

total_n = 100000
env_n = 128
delay_n = 5
samples_per_file = 12800  # env_n * (planning stages per file)
savedir = "../logs/detect"
outdir = "../data/detect"
xpid = "v5_sok"

_logger = util.logger()
_logger.info(f"Initializing {xpid} from {savedir}")
device = torch.device("cuda")

# realpath resolves a symlinked run directory correctly even when the link
# target is relative: os.readlink returns it relative to the link's folder,
# which abspath would then wrongly resolve against the current directory.
ckpdir = os.path.realpath(os.path.expanduser(os.path.join(savedir, xpid)))
outdir = os.path.abspath(os.path.expanduser(outdir))

config_path = os.path.join(ckpdir, 'config_c.yaml')
flags = util.create_flags(config_path, save_flags=False)
disable_thinker = flags.wrapper_type == 1

env = Env(
        name=flags.name,
        env_n=env_n,
        gpu=True,
        train_model=False,
        parallel=False,
        savedir=savedir,
        xpid=xpid,
        ckp=True,
        return_x=True,
        return_h=True,
    )

obs_space = env.observation_space
action_space = env.action_space

actor_param = {
    "obs_space": obs_space,
    "action_space": action_space,
    "flags": flags,
    "tree_rep_meaning": env.get_tree_rep_meaning(),
}
actor_net = ActorNet(**actor_param)

# load the trained actor on CPU, then move it to the GPU in inference mode
path = os.path.join(ckpdir, "ckp_actor.tar")
checkpoint = torch.load(path, map_location=torch.device("cpu"))
actor_net.set_weights(checkpoint["actor_net_state_dict"])
actor_net.to(device)
actor_net.train(False)

state = env.reset()
env_out = init_env_out(state, flags=flags, dim_actions=actor_net.dim_actions, tuple_action=actor_net.tuple_action)
actor_state = actor_net.initial_state(batch_size=env_n, device=device)

# create a fresh output dir: <xpid>-<real_step>-<n> with the first unused n
n = 0
while True:
    name = "%s-%d-%d" % (xpid, checkpoint["real_step"], n)
    outdir_ = os.path.join(outdir, name)
    if not os.path.exists(outdir_):
        os.makedirs(outdir_)
        print(f"Outputting to {outdir_}")
        break
    n += 1
outdir = outdir_

detect_buffer = DetectBuffer(outdir=outdir, t=samples_per_file//env_n, rec_t=flags.rec_t, logger=_logger, delay_n=delay_n)
file_n = total_n // (env_n * detect_buffer.t) + 1
_logger.info(f"Data output directory: {outdir}")
_logger.info(f"Number of file to be generated: {file_n}")

# Sokoban frames are stored as uint8 (scaled by 255) to save space
rescale = "Sokoban" in flags.name

# save setting

env_state_shape = env.observation_space["real_states"].shape[1:]
tree_rep_shape = env.observation_space["tree_reps"].shape[1:]

flags_detect = {
    "dim_actions": actor_net.dim_actions,
    "num_actions": actor_net.num_actions,
    "tuple_actions": actor_net.tuple_action,
    "name": flags.name,
    "env_state_shape": list(env_state_shape),
    "tree_rep_shape": list(tree_rep_shape),
    "rescale": rescale,
    "rec_t": flags.rec_t,
    "ckpdir": ckpdir,
}

yaml_file_path = os.path.join(outdir, 'config_detect.yaml')
with open(yaml_file_path, 'w') as file:
    yaml.dump(flags_detect, file)


Initializing v5_sok from ../logs/detect
Initializing env 0 with device cuda
Model network size: 6637133


Loaded config from ../logs/detect/v5_sok/config_c.yaml


Loaded model net from /mnt/c/Users/chung/Personal/RS/thinker/logs/detect/v5_sok/ckp_model.tar


Tree rep shape:  105
Tree rep meaning:  {'root_td': slice(0, 1, None), 'root_action': slice(1, 6, None), 'root_r': slice(6, 7, None), 'root_v': slice(7, 8, None), 'root_logits': slice(8, 13, None), 'cur_td': slice(13, 14, None), 'cur_action': slice(14, 19, None), 'cur_r': slice(19, 20, None), 'cur_v': slice(20, 21, None), 'cur_logits': slice(21, 26, None), 'cur_reset': slice(26, 27, None), 'one_hot_k': slice(27, 47, None), 'rollout_return': slice(47, 48, None), 'max_rollout_return': slice(48, 49, None), 'rollout_done': slice(49, 50, None), 'action_seq': slice(50, 75, None), 'root_action_table': slice(75, 100, None), 'root_td_table': slice(100, 105, None)}


Data output directory: /mnt/c/Users/chung/Personal/RS/thinker/data/detect/v5_sok-14137536-0
Number of file to be generated: 8


Outputting to /mnt/c/Users/chung/Personal/RS/thinker/data/detect/v5_sok-14137536-0


In [None]:
import torch.nn.functional as F

with torch.set_grad_enabled(False):

    while True:

        actor_out, actor_state = actor_net(env_out=env_out, core_state=actor_state, greedy=False)
        if not disable_thinker:
            primary_action, reset_action = actor_out.action
        else:
            primary_action, reset_action = actor_out.action, None
        state, reward, done, info = env.step(
            primary_action=primary_action,
            reset_action=reset_action,
            action_prob=actor_out.action_prob[-1])
        env_out = create_env_out(actor_out.action, state, reward, done, info, flags=flags)

        # write to detect buffer
        env_state = env_out.xs[0]
        if rescale:
            # store frames as uint8 (scaled by 255) to reduce file size
            env_state = (env_state * 255).to(torch.uint8)

        # Reuse the actions unpacked above: indexing actor_out.action[0] / [1]
        # directly is only right when it is a (primary, reset) pair; with
        # disable_thinker it would index the batch instead.
        xs = {
            "env_state": env_state,
            "tree_rep": state["tree_reps"],
            "pri_action": primary_action,
        }
        if reset_action is not None:
            xs["reset_action"] = reset_action
        y = info['cost']
        # status of the first env; presumably all envs share the same step status — TODO confirm
        step_status = info['step_status'][0].item()

        file_idx = detect_buffer.insert(xs, y, done, step_status)
        if file_idx >= file_n:
            # last file is for validation
            os.rename(f'{outdir}/data_{file_idx}.pt', f'{outdir}/val.pt')
            break


In [1]:
import os
from torch.utils.data import Dataset, DataLoader
import yaml
import argparse


import math
import torch
from torch import nn
from torch.nn import functional as F
from thinker.model_net import BaseNet, FrameEncoder
from thinker import util

class CustomDataset(Dataset):
    """In-memory dataset over the ``.pt`` files written by DetectBuffer.

    Each file holds ``(xs, y)`` where every ``xs[k]`` is indexed as ``[t, :, b]``
    and ``y`` as ``[t, b]``; each (t, b) pair becomes one sample. The
    ``step_status`` and ``done`` entries of xs are dropped.

    Args:
        datadir: directory containing the files.
        transform: optional callable applied to every tensor of a sample's xs.
        data_n: stop preloading once this many samples are stored (None = all).
        prefix: only files whose name starts with this prefix are loaded.
    """

    def __init__(self, datadir, transform=None, data_n=None, prefix="data"):
        self.datadir = datadir
        self.transform = transform
        self.data = []
        self.samples_per_file = None  # t * b of the first loaded file
        self.data_n = data_n
        self.prefix = prefix
        self._preload_data(datadir)  # Preload data

    def _is_full(self):
        """Return whether the data_n sample budget has been reached."""
        return self.data_n is not None and len(self.data) >= self.data_n

    def _preload_data(self, datadir):
        # sorted so that truncation by data_n always keeps the same files
        file_list = sorted(
            f for f in os.listdir(datadir)
            if f.endswith('.pt') and f.startswith(self.prefix)
        )
        for file_name in file_list:
            if self._is_full():
                return
            print(f"Starting to preload {file_name}")
            xs, y = torch.load(os.path.join(datadir, file_name))
            # use each file's own shape rather than assuming it matches the first file
            n_t = xs['env_state'].shape[0]
            n_b = xs['env_state'].shape[2]
            if self.samples_per_file is None:  # Set samples_per_file based on the first file
                self.t = n_t
                self.b = n_b
                self.samples_per_file = n_t * n_b
            xs.pop('step_status')
            xs.pop('done')
            # Flatten data across t and b dimensions for easier indexing
            for t_idx in range(n_t):
                for b_idx in range(n_b):
                    if self._is_full():
                        return
                    flattened_xs = {k: v[t_idx, :, b_idx] for k, v in xs.items()}
                    self.data.append((flattened_xs, y[t_idx, b_idx]))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        xs, y = self.data[idx]
        if self.transform:
            xs = {k: self.transform(v) for k, v in xs.items()}
        return xs, y

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a batch-first sequence.

    Even channels hold sin(pos * w_k), odd channels cos(pos * w_k), with
    w_k = 10000^(-2k / d_model).

    Args:
        d_model: embedding size; odd sizes are supported.
        max_len: longest sequence length that can be encoded.
    """

    def __init__(self, d_model: int, max_len: int = 500):
        super().__init__()

        position = torch.arange(max_len).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
        angles = position * div_term  # (max_len, ceil(d_model / 2))
        pe = torch.zeros(1, max_len, d_model)
        pe[0, :, 0::2] = torch.sin(angles)
        # there are only floor(d_model / 2) odd channels, one fewer than angle
        # columns when d_model is odd
        pe[0, :, 1::2] = torch.cos(angles[:, : d_model // 2])
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        Arguments:
            x: Tensor, shape ``[batch_size, seq_len, embedding_dim]``
        Returns:
            x plus the encodings of positions 0..seq_len-1
        """
        x = x + self.pe[:, :x.size(1)]
        return x

class DetectFrameEncoder(nn.Module):
    """Encode a batch of single frames (with their actions) into flat feature vectors.

    Pipeline: thinker ``FrameEncoder`` -> ReLU -> 3x3 conv to 64 channels ->
    flatten -> linear layer to ``out_size`` features.

    Args:
        input_shape: frame shape passed to FrameEncoder.
        dim_rep_actions: number of action channels passed to FrameEncoder.
        out_size: size of the output feature vector.
        stride: accepted for compatibility but currently unused.
    """

    def __init__(
        self,
        input_shape,
        dim_rep_actions,
        out_size=128,
        stride=2,
    ):
        super().__init__()
        self.out_size = out_size
        self.encoder = FrameEncoder(
            prefix="se",
            actions_ch=dim_rep_actions,
            input_shape=input_shape,
            size_nn=1,
            downscale_c=2,
            concat_action=False,
        )

        enc_shape = self.encoder.out_shape
        channels = enc_shape[0]
        layers = []
        for next_channels in [64]:
            layers += [
                nn.ReLU(),
                nn.Conv2d(
                    in_channels=channels,
                    out_channels=next_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                ),
            ]
            channels = next_channels
        self.conv = nn.Sequential(*layers)
        flat_size = channels * enc_shape[1] * enc_shape[2]
        self.fc = nn.Sequential(nn.Linear(flat_size, self.out_size))

    def forward(self, x, action):
        # x in shape of (B, C, H, W)
        features, _ = self.encoder(x, done=None, actions=action, state={})
        features = torch.flatten(self.conv(features), start_dim=1)
        return self.fc(features)
        
class DetectNet(BaseNet):
    def __init__(
        self,
        env_state_shape,
        tree_rep_shape,
        dim_actions,
        num_actions,
        detect_ab=(0,0),
        clone=False,
        tran_layer_n=3,
    ):
        """Transformer classifier over the rec_t steps of one planning stage.

        Args:
            env_state_shape: frame shape (C, H, W).
            tree_rep_shape: tree representation shape (C,).
            dim_actions: number of action dimensions.
            num_actions: number of discrete actions.
            detect_ab: ablation codes for (first step, remaining steps): 1 zeros
                env_state, 2 zeros tree_rep, 3 zeros both, 0 keeps the input.
            clone: clone env_state / tree_rep before zeroing so the caller's
                tensors are not modified in place.
            tran_layer_n: number of transformer encoder layers.
        """
        super(DetectNet, self).__init__()

        self.env_state_shape = env_state_shape # in (C, H, W)
        self.tree_rep_shape = tree_rep_shape # in (C,)
        self.dim_actions = dim_actions
        self.num_actions = num_actions
        self.dim_rep_actions = self.dim_actions if self.dim_actions > 1 else self.num_actions

        self.detect_ab = detect_ab
        self.clone = clone

        self.enc_out_size = 128
        tran_nhead = 8
        # Pad the frame-feature size so that embed_size is divisible by tran_nhead.
        # NOTE: adds a full tran_nhead when already divisible; kept as-is so the
        # layer sizes (and existing checkpoints) do not change.
        remainder = tran_nhead - ((self.enc_out_size + tree_rep_shape[0] + self.dim_rep_actions + 1) % tran_nhead)
        self.enc_out_size += remainder
        self.true_x_encoder = DetectFrameEncoder(input_shape=env_state_shape, dim_rep_actions=self.dim_rep_actions, out_size=self.enc_out_size)
        self.pred_x_encoder = DetectFrameEncoder(input_shape=env_state_shape, dim_rep_actions=self.dim_rep_actions, out_size=self.enc_out_size)

        # Use dim_rep_actions, matching the padding above and the action width given
        # to the frame encoders (presumably the width of util.encode_action's output).
        # num_actions only agrees with it when dim_actions == 1.
        self.embed_size = self.enc_out_size + tree_rep_shape[0] + self.dim_rep_actions + 1
        self.pos_encoder = PositionalEncoding(self.embed_size)

        encoder_layer = nn.TransformerEncoderLayer(d_model=self.embed_size,
                                                   nhead=tran_nhead,
                                                   dim_feedforward=512,
                                                   batch_first=True)
        self.transformer_encoder = nn.TransformerEncoder(encoder_layer, tran_layer_n)
        self.classifier = nn.Linear(self.embed_size, 1)

        # portion of negative class; updated outside the optimizer (see train_epoch)
        self.beta = nn.Parameter(torch.tensor(0.5), requires_grad=False)

    def forward(self, env_state, tree_rep, action, reset):
        """
        Forward pass of detection nn
        Args:
            env_state: float Tensor in shape of (B, rec_t, C, H, W); true and predicted frame
            tree_rep: float Tensor in shape of (B, rec_t, C); model output
            action: uint Tensor in shape of (B, rec_t, dim_actions); action (real / imaginary)
            reset: bool Tensor in shape of  (B, rec_t); reset action
        Return:
            logit: float Tensor in shape of (B); logit of classifier output
            p: float Tensor in shape of (B); prob of classifier output
        """
        B, rec_t = env_state.shape[:2]
        # ablations: index 0 of detect_ab targets step 0, index 1 the remaining steps
        if self.detect_ab[0] in [1, 3] or self.detect_ab[1] in [1, 3]:
            if self.clone: env_state = env_state.clone()
            if self.detect_ab[0] in [1, 3]:
                env_state[:, 0] = 0.
            if self.detect_ab[1] in [1, 3]:
                env_state[:, 1:] = 0.
        if self.detect_ab[0] in [2, 3] or self.detect_ab[1] in [2, 3]:
            if self.clone: tree_rep = tree_rep.clone()
            if self.detect_ab[0] in [2, 3]:
                tree_rep[:, 0] = 0.
            if self.detect_ab[1] in [2, 3]:
                tree_rep[:, 1:] = 0.

        action = util.encode_action(action, self.dim_actions, self.num_actions)
        # step 0 holds the true frame, later steps the predicted frames
        true_proc_x = self.true_x_encoder(env_state[:,0], action[:,0])
        pred_proc_x = self.pred_x_encoder(
            torch.flatten(env_state[:,1:], 0, 1),
            torch.flatten(action[:,1:], 0, 1)
                                        )
        true_proc_x = true_proc_x.view(B, self.enc_out_size).unsqueeze(1) # (B, 1, C)
        pred_proc_x = pred_proc_x.view(B, rec_t - 1, self.enc_out_size)  # (B, rec_t - 1, C)
        proc_x = torch.concat([true_proc_x, pred_proc_x], dim=1) # (B, rec_t, C)

        embed = [proc_x, tree_rep, action, reset.unsqueeze(-1)]
        embed = torch.concat(embed, dim=2) # (B, rec_t, embed_size)
        embed_pos = self.pos_encoder(embed)
        out = self.transformer_encoder(embed_pos)
        # classify from the last step's output
        logit = self.classifier(out[:, -1, :]).view(B)
        return logit, torch.sigmoid(logit)

def transform_data(xs, device, data_flags=None):
    """Move a dataset batch to device and rename its keys for DetectNet.forward.

    Args:
        xs (dict): batch with "env_state" and "pri_action", optionally
            "tree_rep" and "reset_action".
        device: target torch device.
        data_flags: namespace providing ``rescale`` and ``tuple_actions``;
            defaults to the module-level ``flags_data`` read from config_detect.yaml.
    Returns:
        dict with "env_state", "action" and, when present in xs, "tree_rep" / "reset".
    """
    if data_flags is None:
        data_flags = flags_data  # module-level global set when the data config is loaded
    xs_ = {}

    env_state = xs["env_state"]
    if data_flags.rescale:
        # frames were stored as uint8 scaled by 255
        env_state = env_state.float() / 255
    xs_["env_state"] = env_state.to(device)

    if "tree_rep" in xs:
        xs_["tree_rep"] = xs["tree_rep"].to(device)

    action = xs["pri_action"]
    if not data_flags.tuple_actions:
        # add the trailing action-dimension axis expected by DetectNet
        action = action.unsqueeze(-1)
    xs_["action"] = action.to(device)

    if "reset_action" in xs:
        xs_["reset"] = xs["reset_action"].to(device)
    return xs_

def evaluate_detect(target_y, pred_y):
    """Summarise binary detection quality at a 0.5 probability threshold.

    Args:
        target_y: ground-truth labels (bool or 0/1 tensor).
        pred_y: predicted positive-class probabilities, same shape as target_y.
    Returns:
        dict of python floats: accuracy ("acc"), recall ("rec"), precision
        ("prec"), F1 ("f1") and the fraction of negative targets ("neg_p").
    """
    eps = 1e-6
    truth = target_y.float()
    guess = (pred_y > 0.5).float()

    hits = (guess * truth).sum().float()
    recall = (hits / (truth.sum().float() + eps)).item()
    precision = (hits / (guess.sum().float() + eps)).item()

    return {
        "acc": (guess == truth).float().mean().item(),
        "rec": recall,
        "prec": precision,
        "f1": 2 * (precision * recall) / (precision + recall + eps),
        "neg_p": 1 - truth.float().mean().item(),
        }

def train_epoch(detect_net, dataloader, optimizer, device, flags, train=True):
    """Run one pass over dataloader, optimizing detect_net if train else only evaluating.

    The loss is a class-weighted BCE: positives are weighted by detect_net.beta
    (an EMA of the negative-class fraction) and negatives by 1 - beta.

    Args:
        detect_net: DetectNet-like module returning (logit, prob) and holding ``beta``.
        dataloader: yields (xs, target_y) batches from CustomDataset.
        optimizer: optimizer stepped when train is True (unused otherwise).
        device: device batches are moved to.
        flags: currently unused; kept for interface compatibility.
        train: train mode (gradients, optimizer steps, beta updates) vs eval mode.
    Returns:
        dict mapping each evaluate_detect metric plus "loss" to its mean over
        batches (empty if dataloader yields nothing).
    """
    detect_net.train(train)
    running_train_eval = {}
    step = 0
    with torch.set_grad_enabled(train):
        for xs, target_y in dataloader:
            xs = transform_data(xs, device)
            target_y = target_y.to(device)

            logit, pred_y = detect_net(**xs)
            if train:
                # Only training batches update the beta EMA: beta is a Parameter
                # (part of the state_dict), so evaluation must not change it.
                n_mean_y = torch.mean((~target_y).float()).item()
                detect_net.beta.data = 0.99 * detect_net.beta.data + (1 - 0.99) * n_mean_y
                detect_net.beta.data.clamp_(0.05, 0.95)
            weights = torch.where(target_y == 1, detect_net.beta.data, 1-detect_net.beta.data)
            loss = F.binary_cross_entropy_with_logits(logit, target_y.float(), weight=weights)
            train_eval = evaluate_detect(target_y, pred_y)
            train_eval["loss"] = loss.item()

            if train:
                optimizer.zero_grad()
                loss.backward()
                optimizer.step()

            for key, value in train_eval.items():
                running_train_eval[key] = running_train_eval.get(key, 0) + value
            step += 1
    return {key: val / step for (key, val) in running_train_eval.items()}





In [3]:

flags = argparse.Namespace()

flags.datadir = "../data/detect/v5_sok-14137536-0/"
flags.xpid = "test"
flags.batch_size = 128
flags.learning_rate = 0.0001
flags.num_epochs = 100
flags.data_n = 50000
flags.ckp = False

# resolve datadir in both branches so ckpdir / ckp_path are absolute when resuming too
flags.datadir = os.path.abspath(os.path.expanduser(flags.datadir))
if not flags.ckp:
    # create ckp dir: the first unused name among xpid, xpid_1, xpid_2, ...
    xpid_n = 0
    while True:
        xpid_ = flags.xpid if xpid_n == 0 else flags.xpid + f"_{xpid_n}"
        ckpdir = os.path.join(flags.datadir, xpid_)
        xpid_n += 1
        if not os.path.exists(ckpdir):
            os.mkdir(ckpdir)
            flags.xpid = xpid_
            break
else:
    ckpdir = os.path.join(flags.datadir, flags.xpid)
flags.ckpdir = ckpdir
flags.ckp_path = os.path.join(ckpdir, "ckp_detect.tar")
print(f"Checkpoint path: {flags.ckp_path}")

# load data
dataset = CustomDataset(datadir=flags.datadir, transform=None, data_n=flags.data_n)
dataloader = DataLoader(dataset, batch_size=flags.batch_size, shuffle=True)

val_dataset = CustomDataset(datadir=flags.datadir, transform=None, data_n=2000, prefix="val")
val_dataloader = DataLoader(val_dataset, batch_size=flags.batch_size, shuffle=True)

# load setting
yaml_file_path = os.path.join(flags.datadir, 'config_detect.yaml')
with open(yaml_file_path, 'r') as file:
    flags_data = yaml.safe_load(file)
flags_data = argparse.Namespace(**flags_data)
# Merge the two flags; this run's own settings take precedence. config_detect.yaml
# also stores "ckpdir" (the thinker model dir), which would otherwise overwrite
# this run's flags.ckpdir; it remains available as flags_data.ckpdir.
flags = argparse.Namespace(**{**vars(flags_data), **vars(flags)})

Checkpoint path: /mnt/c/Users/chung/Personal/RS/thinker/data/detect/v5_sok-14137536-0/test/ckp_detect.tar
Starting to preload data_0.pt
Starting to preload data_1.pt
Starting to preload data_2.pt
Starting to preload data_3.pt
Starting to preload val.pt


In [4]:
# initialize net
device = torch.device("cuda")
detect_net = DetectNet(
    env_state_shape=flags_data.env_state_shape,
    tree_rep_shape=flags_data.tree_rep_shape,
    dim_actions=flags_data.dim_actions,
    num_actions=flags_data.num_actions,
)

# load optimizer
optimizer = torch.optim.Adam(detect_net.parameters(), lr=flags.learning_rate)

# resume from a checkpoint if requested (loaded on CPU, moved to device below)
epoch = 0
if flags.ckp:
    checkpoint = torch.load(flags.ckp_path, map_location=torch.device("cpu"))
    optimizer.load_state_dict(checkpoint["optimizer_state_dict"])
    detect_net.load_state_dict(checkpoint["net_state_dict"])
    epoch = checkpoint["epoch"]
    del checkpoint

detect_net = detect_net.to(device)
util.optimizer_to(optimizer, device)

In [5]:
# Train, evaluate on the validation split every epoch, and checkpoint every
# 5 epochs as well as after the final one.
while epoch < flags.num_epochs:
    train_stat = train_epoch(detect_net, dataloader, optimizer, device, flags, train=True)
    val_stat = train_epoch(detect_net, val_dataloader, None, device, flags, train=False)

    # " metric:train (val)" for every tracked metric
    print_str = f'Epoch {epoch+1}/{flags.num_epochs},' + "".join(
        f" {key}:{train_stat[key]:.4f} ({val_stat[key]:.4f})" for key in train_stat
    )
    print(print_str)

    epoch += 1
    if epoch % 5 == 0 or epoch >= flags.num_epochs:
        # save checkpoint
        d = {
            "epoch": epoch,
            "flags": flags,
            "optimizer_state_dict": optimizer.state_dict(),
            "net_state_dict": detect_net.state_dict(),
        }
        torch.save(d, flags.ckp_path)
        print(f"Checkpoint saved to {flags.ckp_path}")

Epoch 1/100, acc:0.8860 (0.8019) rec:0.6524 (0.9215) prec:0.4468 (0.3401) f1:0.5092 (0.4894) neg_p:0.8755 (0.8934) loss:0.0991 (0.0748)
Epoch 2/100, acc:0.9169 (0.9297) rec:0.9284 (0.7902) prec:0.6207 (0.6574) f1:0.7336 (0.7077) neg_p:0.8755 (0.8925) loss:0.0425 (0.0711)
Epoch 3/100, acc:0.9340 (0.9204) rec:0.9470 (0.8364) prec:0.6729 (0.5887) f1:0.7789 (0.6869) neg_p:0.8755 (0.8937) loss:0.0336 (0.0759)
Epoch 4/100, acc:0.9429 (0.8788) rec:0.9612 (0.8676) prec:0.7017 (0.4622) f1:0.8042 (0.5999) neg_p:0.8755 (0.8939) loss:0.0281 (0.0607)
Epoch 5/100, acc:0.9514 (0.8806) rec:0.9708 (0.8798) prec:0.7376 (0.4695) f1:0.8313 (0.6008) neg_p:0.8755 (0.8939) loss:0.0240 (0.0758)
Checkpoint saved to /mnt/c/Users/chung/Personal/RS/thinker/data/detect/v5_sok-14137536-0/test/ckp_detect.tar
Epoch 6/100, acc:0.9595 (0.8639) rec:0.9761 (0.8884) prec:0.7695 (0.4316) f1:0.8548 (0.5732) neg_p:0.8755 (0.8931) loss:0.0200 (0.0827)
Epoch 7/100, acc:0.9651 (0.9388) rec:0.9816 (0.6830) prec:0.7943 (0.6932) f

KeyboardInterrupt: 

In [14]:
# deprecated

import os 
from torch.utils.data import Dataset, DataLoader
import torch
from thinker import util

# absolute path of the (deprecated) dataset directory
datadir = os.path.abspath(os.path.expanduser("../data/detect/v5_sok-5993808-1/"))

class CustomDataset(Dataset):
    """Lazily-loading dataset over the .pt files in datadir (deprecated).

    Each file holds (xs, y) with every xs[k] indexed as [t, :, b] and y as
    [t, b]; every (file, t, b) triple is one sample, and all files are assumed
    to have the first file's t and b. The most recently loaded file is cached,
    so consecutive indices from the same file do not re-read it from disk.

    Args:
        datadir: directory with the .pt files (files are used in sorted order).
        transform: optional callable applied to every tensor of a sample's xs.
    Raises:
        FileNotFoundError: if datadir contains no .pt file.
    """

    def __init__(self, datadir, transform=None):
        self.datadir = datadir
        # sorted so that the index -> (file, t, b) mapping is deterministic
        self.file_list = sorted(f for f in os.listdir(datadir) if f.endswith('.pt'))
        if not self.file_list:
            raise FileNotFoundError(f"no .pt files found in {datadir}")
        self.transform = transform
        self._cached_idx, self._cached = None, None
        xs, _ = self._load(0)
        self.t = xs['env_state'].shape[0]
        self.b = xs['env_state'].shape[2]
        self.samples_per_file = self.t * self.b

    def _load(self, file_idx):
        """Return (xs, y) of file file_idx without step_status / done, caching the last file."""
        if self._cached_idx != file_idx:
            xs, y = torch.load(os.path.join(self.datadir, self.file_list[file_idx]))
            xs.pop('step_status')
            xs.pop('done')
            self._cached_idx, self._cached = file_idx, (xs, y)
        return self._cached

    def __len__(self):
        return len(self.file_list) * self.samples_per_file

    def __getitem__(self, idx):
        file_idx, within_file_idx = divmod(idx, self.samples_per_file)
        t_idx, b_idx = divmod(within_file_idx, self.b)
        xs, y = self._load(file_idx)
        # build a new dict so the cached file contents are never mutated
        sample = {k: v[t_idx, :, b_idx] for k, v in xs.items()}
        if self.transform:
            sample = {k: self.transform(v) for k, v in sample.items()}
        return sample, y[t_idx, b_idx]

# Build the (deprecated) lazily-loading dataset and draw one shuffled batch
dataset = CustomDataset(datadir)
train_dataloader = DataLoader(dataset, batch_size=64, shuffle=True)
batch_iter = iter(train_dataloader)
train_features, train_labels = next(batch_iter)


0.2288818359375