In [1]:
import torch
import trimesh
import numpy as np
import os
import csv 
import json
import math 

from meshgpt_pytorch import (
    MeshTransformerTrainer,
    MeshAutoencoderTrainer,
    MeshAutoencoder,
    MeshTransformer
)

def get_3d_data(file_path):
    """Load a mesh and return it normalized and in a canonical vertex/face order.

    Args:
        file_path: Path to a mesh file readable by ``trimesh.load`` (loaded
            with ``force='mesh'``). An already-loaded mesh-like object that
            exposes ``.vertices`` and ``.faces`` is also accepted, in which
            case nothing is read from disk.

    Returns:
        (vertices, faces):
            vertices -- (V, 3) float array, centered on the vertex mean and
                scaled so the largest absolute coordinate is 0.95, sorted
                with z as the primary key, then x, then y.
            faces -- (F, k) int array of indices into the sorted vertices.
                Each face is cyclically rotated so its lowest vertex index
                comes first (winding order is preserved), and faces are
                ordered by their lowest index, then the next, and so on.
    """
    if hasattr(file_path, "vertices") and hasattr(file_path, "faces"):
        mesh = file_path
    else:
        mesh = trimesh.load(file_path, force='mesh')

    vertices = np.asarray(mesh.vertices, dtype=float)
    faces = np.asarray(mesh.faces, dtype=np.int64)

    centered_vertices = vertices - np.mean(vertices, axis=0)
    max_abs = np.max(np.abs(centered_vertices))
    # Guard against a degenerate mesh whose vertices all coincide (0 / 0 -> NaN).
    if max_abs > 0:
        vertices_normalized = centered_vertices / (max_abs / 0.95)
    else:
        vertices_normalized = centered_vertices

    # np.lexsort treats the LAST key as the primary one, so this sorts
    # vertices by z first, then x, then y.
    vertex_order = np.lexsort((vertices_normalized[:, 1],
                               vertices_normalized[:, 0],
                               vertices_normalized[:, 2]))
    vertices_normalized_sorted = vertices_normalized[vertex_order]

    # Inverse permutation (old vertex index -> new vertex index) to reindex faces.
    old_to_new = np.empty_like(vertex_order)
    old_to_new[vertex_order] = np.arange(len(vertex_order))
    faces_reindexed = old_to_new[faces]

    # Rotate every face so its lowest vertex index comes first; a cyclic
    # shift keeps the winding order (and therefore the face orientation).
    verts_per_face = faces_reindexed.shape[1]
    start = np.argmin(faces_reindexed, axis=1)
    rotation = (start[:, None] + np.arange(verts_per_face)) % verts_per_face
    faces_rotated = np.take_along_axis(faces_reindexed, rotation, axis=1)

    # Sort faces by their lowest (first) index, then the second, then the
    # third. The key order is reversed because lexsort's last key is primary.
    faces_sorted = faces_rotated[np.lexsort(faces_rotated.T[::-1])]

    return vertices_normalized_sorted, faces_sorted

def augment_mesh_scalar(vertices, scale_factor):
    """Uniformly scale all XYZ coordinates by ``scale_factor``.

    For NumPy array input a new array is returned; ``vertices`` is not modified.
    """
    return vertices * scale_factor

def generate_scale_factors(num_examples, lower_limit=0.75, upper_limit=1.25):
    """Draw ``num_examples`` scale factors uniformly from [lower_limit, upper_limit).

    Uses NumPy's global random state, so results depend on ``np.random.seed``.
    """
    return np.random.uniform(low=lower_limit, high=upper_limit, size=num_examples)

def jitter_mesh(vertices, jitter_factor=0.01):
    """Offset every coordinate by independent uniform noise in [-jitter_factor, jitter_factor).

    Draws from NumPy's global random state.
    """
    noise = np.random.uniform(low=-jitter_factor, high=jitter_factor, size=vertices.shape)
    return vertices + noise

def augment_mesh(vertices, scale_factor, jitter_factor=0.0):
    """Scale ``vertices`` by ``scale_factor``, optionally jittering them first.

    Args:
        vertices: array of vertex coordinates.
        scale_factor: multiplier applied to every coordinate.
        jitter_factor: when non-zero, each coordinate is first offset by
            uniform noise in [-jitter_factor, jitter_factor) drawn from NumPy's
            global RNG. The default 0.0 disables jitter and draws nothing from
            the RNG, which matches the previous behavior (where jitter was
            commented out).

    Returns:
        The transformed vertices as a new array.
    """
    if jitter_factor:
        vertices = vertices + np.random.uniform(-jitter_factor, jitter_factor,
                                                size=np.shape(vertices))
    return vertices * scale_factor
 

def load_models(directory, num_examples, variations, scale_range=(0.7, 0.9),
                extensions=(".obj", ".glb", ".off")):
    """Build an augmented, text-labelled mesh dataset from a directory.

    Each file in ``directory`` whose name ends with one of ``extensions``
    (case-insensitive) is loaded with ``get_3d_data``. For each file,
    ``variations`` random scale factors are drawn from ``scale_range``, and
    every scaled copy is repeated ``num_examples`` times.

    Args:
        directory: folder containing mesh files.
        num_examples: number of identical entries emitted per scaled copy.
        variations: number of random scale factors drawn per mesh.
        scale_range: (low, high) bounds passed to ``generate_scale_factors``.
        extensions: file suffixes treated as meshes.

    Returns:
        list of dicts with keys "vertices" (nested list), "faces" (nested
        list) and "texts" (the file name without its extension).
    """
    lower, upper = scale_range
    suffixes = tuple(ext.lower() for ext in extensions)
    obj_datas = []

    print(f"num_examples: {num_examples}")
    # sorted(): os.listdir order is arbitrary, so without sorting the dataset
    # order (and which random draws go to which file) differs across machines.
    for filename in sorted(os.listdir(directory)):
        if not filename.lower().endswith(suffixes):
            continue
        file_path = os.path.join(directory, filename)
        label = os.path.splitext(filename)[0]

        scale_factors = generate_scale_factors(variations, lower, upper)
        vertices, faces = get_3d_data(file_path)

        for scale_factor in scale_factors:
            aug_vertices = augment_mesh(vertices.copy(), scale_factor)
            for _ in range(num_examples):
                # Fresh lists per entry so entries never alias one another.
                obj_datas.append({"vertices": aug_vertices.tolist(),
                                  "faces": faces.tolist(),
                                  "texts": label})
    return obj_datas
  

def load_json(file, num_examples):
    """Read mesh records from a JSON file and convert them to tensors.

    Each record ("vertices", "faces", "texts") is repeated ``num_examples``
    times. Vertices become float tensors and faces become long tensors, and
    every repetition gets its own tensor objects.
    """
    with open(file, "r") as handle:
        records = json.load(handle)
    return [
        {
            "vertices": torch.tensor(record["vertices"], dtype=torch.float),
            "faces": torch.tensor(record["faces"], dtype=torch.long),
            "texts": record["texts"],
        }
        for record in records
        for _ in range(num_examples)
    ]
  
                        
         

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
import torch
from torch.utils.data import Dataset, DataLoader 
from tqdm import tqdm
import numpy as np
import torch
 
class MeshDataset(Dataset): 
    """In-memory dataset of mesh dicts.

    Each item holds "vertices", "faces" and either "texts" or, after
    ``embed_texts``, "text_embeds". Items are returned exactly as stored.
    """

    def __init__(self, data): 
        self.data = data
        print(f"Got {len(data)} data")

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx): 
        return self.data[idx]

    def embed_texts(self, transformer): 
        """Replace each item's "texts" entry with its embedding under "text_embeds".

        Each distinct text is embedded once via ``transformer.embed_texts``.
        Items are mutated in place. A second call raises KeyError because
        "texts" has already been removed.
        """
        # A single list is used both for the embedding call and for pairing
        # results back to texts, so the alignment is explicit.
        unique_texts = list(set(item['texts'] for item in self.data))

        text_embeddings = transformer.embed_texts(unique_texts)
        print(f"Got text_embeddings: {len(text_embeddings)}") 
        text_embedding_dict = dict(zip(unique_texts, text_embeddings))

        for item in self.data:
            item['text_embeds'] = text_embedding_dict.get(item['texts'], None)
            del item['texts']

    @staticmethod
    def _to_python_scalar(value):
        # Tensor / NumPy scalars -> plain Python numbers, so the OBJ text reads
        # "0.5" instead of "tensor(0.5000)". Plain ints/floats pass through.
        return value.item() if hasattr(value, "item") else value

    def sample_obj(self, max_models=31, translation_distance=0.5,
                   obj_file_path="./combined_3d_models.obj"):
        """Write the first ``max_models`` meshes side by side into one OBJ file.

        Model ``r`` is shifted along x by ``translation_distance * (r / 0.2 - 1)``
        so the meshes do not overlap. Items may hold nested lists (as built by
        ``load_models``) or tensors (as built by ``load_json``).
        """
        scalar = self._to_python_scalar
        all_vertices = []
        all_faces = []
        vertex_offset = 0

        for r, mesh in enumerate(self.data[:max_models]):
            x_shift = translation_distance * (r / 0.2 - 1)
            for vertex in mesh["vertices"]:
                all_vertices.append(
                    f"v {scalar(vertex[0]) + x_shift} {scalar(vertex[1])} {scalar(vertex[2])}\n"
                )

            for face in mesh["faces"]:
                a, b, c = (scalar(face[k]) + vertex_offset for k in range(3))
                # Report faces that reference vertices not written so far
                # (would produce an invalid OBJ), but keep writing as before.
                if min(a, b, c) < 0 or max(a, b, c) >= len(all_vertices):
                    print(f"face {[a, b, c]} out of range for {len(all_vertices)} vertices")
                # OBJ vertex indices are 1-based.
                all_faces.append(f"f {a + 1} {b + 1} {c + 1}\n")

            vertex_offset = len(all_vertices)

        obj_file_content = "".join(all_vertices) + "".join(all_faces)

        # Save to a single file
        with open(obj_file_path, "w") as file:
            file.write(obj_file_content)

        print(obj_file_path)
         

In [3]:
import json
# Seed the RNGs so the random scale augmentation (and later model
# initialisation / sampling) is reproducible under Restart & Run All.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)

# Build the augmented dataset from the toy meshes and cache it as JSON ...
raw_tables = load_models(r"toy", 4, 5)
with open("toy/data.json", "w") as json_file:
    json.dump(raw_tables, json_file)

# ... then reload it: load_json converts vertices/faces to tensors and
# repeats every entry twice.
tables = load_json("toy/data.json", 2)
dataset = MeshDataset(tables)

unique_values = set(item["texts"] for item in dataset.data)

print(len(unique_values))
print(unique_values)


num_examples: 4
Got 160 data
4
{'cylinder', 'icosphere', 'cone', 'cube'}


In [4]:
autoencoder = MeshAutoencoder(num_discrete_coors = 128)


def _count_params(module):
    """Total number of scalar elements across all parameters of ``module``."""
    return sum(p.numel() for p in module.parameters())


total_params = _count_params(autoencoder.encoders)
print(f"encoders Total parameters: {total_params}")
total_params = _count_params(autoencoder.decoders)
print(f"decoders Total parameters: {total_params}")


encoders Total parameters: 661376
decoders Total parameters: 6676480


In [5]:
# Inspect the encoder stack before training.
total_params = sum(p.numel() for p in autoencoder.encoders.parameters())
print(f"Total parameters: {total_params}")
print(autoencoder.encoders)

autoencoder_trainer = MeshAutoencoderTrainer(model = autoencoder,
                                             learning_rate = 1e-3,
                                             warmup_steps = 10,
                                             dataset = dataset,
                                             num_train_steps = 100,
                                             batch_size = 16,
                                             grad_accum_every = 1)

# Up to 40 epochs; stop_at_loss presumably ends training early once the
# loss falls below 0.25 (not reached in the recorded run).
loss = autoencoder_trainer.train(40, stop_at_loss = 0.25)

# Optional second phase (previously left as commented-out code): create a new
# MeshAutoencoderTrainer with learning_rate=1e-4 and checkpoint_every_epoch=20
# (other arguments unchanged) and call .train(180, stop_at_loss=0.25).

# The checkpoint folder must exist before saving.
os.makedirs("toy_output", exist_ok=True)
autoencoder_trainer.save(f'toy_output/mesh-encoder_2_loss_{loss:.3f}.pt')

Total parameters: 661376
ModuleList(
  (0): SAGEConv(64, 128, aggr=mean)
  (1): SAGEConv(128, 256, aggr=mean)
  (2): SAGEConv(256, 256, aggr=mean)
  (3): SAGEConv(256, 576, aggr=mean)
)


Epoch 1/40: 100%|██████████| 10/10 [00:10<00:00,  1.04s/it, loss=4.54]


Epoch 1 average loss: 4.801162958145142


Epoch 2/40: 100%|██████████| 10/10 [00:01<00:00,  6.00it/s, loss=3.96]


Epoch 2 average loss: 4.1680371284484865


Epoch 3/40: 100%|██████████| 10/10 [00:01<00:00,  6.53it/s, loss=3.97]


Epoch 3 average loss: 3.7970282316207884


Epoch 4/40: 100%|██████████| 10/10 [00:01<00:00,  6.14it/s, loss=3.37]


Epoch 4 average loss: 3.4966192960739138           avg loss speed: 0.7587901433308915


Epoch 5/40: 100%|██████████| 10/10 [00:01<00:00,  6.10it/s, loss=3.21]


Epoch 5 average loss: 3.2951470613479614           avg loss speed: 0.525414490699768


Epoch 6/40: 100%|██████████| 10/10 [00:01<00:00,  5.91it/s, loss=3.09]


Epoch 6 average loss: 3.130061984062195           avg loss speed: 0.39953621228535985


Epoch 7/40: 100%|██████████| 10/10 [00:01<00:00,  6.20it/s, loss=3.35]


Epoch 7 average loss: 2.9931079864501955           avg loss speed: 0.3141681273778274


Epoch 8/40: 100%|██████████| 10/10 [00:01<00:00,  6.23it/s, loss=2.99]


Epoch 8 average loss: 2.8937754154205324           avg loss speed: 0.2456635951995847


Epoch 9/40: 100%|██████████| 10/10 [00:01<00:00,  5.61it/s, loss=2.63]


Epoch 9 average loss: 2.77081093788147           avg loss speed: 0.2348375240961711


Epoch 10/40: 100%|██████████| 10/10 [00:01<00:00,  6.18it/s, loss=2.58]


Epoch 10 average loss: 2.6357030391693117           avg loss speed: 0.2501950740814207


Epoch 11/40: 100%|██████████| 10/10 [00:01<00:00,  6.04it/s, loss=3.06]


Epoch 11 average loss: 2.5768152236938477           avg loss speed: 0.18994790712992327 epochs left: 11.99


Epoch 12/40: 100%|██████████| 10/10 [00:01<00:00,  5.99it/s, loss=2.46]


Epoch 12 average loss: 2.717967319488525           avg loss speed: -0.05685758590698198


Epoch 13/40: 100%|██████████| 10/10 [00:01<00:00,  6.03it/s, loss=2.46]


Epoch 13 average loss: 2.4359117984771728           avg loss speed: 0.20758339564005546


Epoch 14/40: 100%|██████████| 10/10 [00:01<00:00,  6.09it/s, loss=2.15]


Epoch 14 average loss: 2.2725124597549438           avg loss speed: 0.3043856541315719


Epoch 15/40: 100%|██████████| 10/10 [00:01<00:00,  6.19it/s, loss=2.07]


Epoch 15 average loss: 2.1646111369132996           avg loss speed: 0.3108527223269144


Epoch 16/40: 100%|██████████| 10/10 [00:01<00:00,  6.30it/s, loss=2.05]


Epoch 16 average loss: 2.050452184677124           avg loss speed: 0.24055961370468149


Epoch 17/40: 100%|██████████| 10/10 [00:01<00:00,  5.96it/s, loss=1.94]


Epoch 17 average loss: 1.937397539615631           avg loss speed: 0.22512772083282462


Epoch 18/40: 100%|██████████| 10/10 [00:01<00:00,  6.19it/s, loss=1.89]


Epoch 18 average loss: 1.90082848072052           avg loss speed: 0.1499918063481651 epochs left: 10.67


Epoch 19/40: 100%|██████████| 10/10 [00:01<00:00,  6.32it/s, loss=1.65]


Epoch 19 average loss: 1.837030017375946           avg loss speed: 0.12586271762847923 epochs left: 12.21


Epoch 20/40: 100%|██████████| 10/10 [00:01<00:00,  6.20it/s, loss=1.66]


Epoch 20 average loss: 1.7137021541595459           avg loss speed: 0.1780498584111534 epochs left: 7.94


Epoch 21/40: 100%|██████████| 10/10 [00:01<00:00,  6.31it/s, loss=1.58]


Epoch 21 average loss: 1.6149927496910095           avg loss speed: 0.20219413439432765


Epoch 22/40: 100%|██████████| 10/10 [00:01<00:00,  6.36it/s, loss=2.01]


Epoch 22 average loss: 1.6269951105117797           avg loss speed: 0.09491319656372066 epochs left: 13.98


Epoch 23/40: 100%|██████████| 10/10 [00:01<00:00,  6.05it/s, loss=1.61]


Epoch 23 average loss: 1.7283597230911254           avg loss speed: -0.07646305163701372


Epoch 24/40: 100%|██████████| 10/10 [00:01<00:00,  6.19it/s, loss=1.52]


Epoch 24 average loss: 1.5252786517143249           avg loss speed: 0.13150387605031333 epochs left: 9.32


Epoch 25/40: 100%|██████████| 10/10 [00:01<00:00,  6.15it/s, loss=1.37]


Epoch 25 average loss: 1.4678341388702392           avg loss speed: 0.1590436895688374 epochs left: 7.34


Epoch 26/40: 100%|██████████| 10/10 [00:01<00:00,  6.20it/s, loss=1.35]


Epoch 26 average loss: 1.3803503155708312           avg loss speed: 0.19347385565439845 epochs left: 5.58


Epoch 27/40: 100%|██████████| 10/10 [00:01<00:00,  6.15it/s, loss=1.21]


Epoch 27 average loss: 1.2862542033195496           avg loss speed: 0.17156683206558188 epochs left: 5.75


Epoch 28/40: 100%|██████████| 10/10 [00:01<00:00,  6.41it/s, loss=1.18]


Epoch 28 average loss: 1.197661292552948           avg loss speed: 0.18048492670059213 epochs left: 4.97


Epoch 29/40: 100%|██████████| 10/10 [00:01<00:00,  6.44it/s, loss=1.29]


Epoch 29 average loss: 1.154736590385437           avg loss speed: 0.13335201342900604 epochs left: 6.41


Epoch 30/40: 100%|██████████| 10/10 [00:01<00:00,  6.35it/s, loss=0.985]


Epoch 30 average loss: 1.0822443425655366           avg loss speed: 0.1306396861871082 epochs left: 5.99


Epoch 31/40: 100%|██████████| 10/10 [00:01<00:00,  6.21it/s, loss=1.02]


Epoch 31 average loss: 1.0353250622749328           avg loss speed: 0.10955567955970769 epochs left: 6.71


Epoch 32/40: 100%|██████████| 10/10 [00:01<00:00,  6.19it/s, loss=1.04]


Epoch 32 average loss: 0.9589943587779999           avg loss speed: 0.13177430629730225 epochs left: 5.00


Epoch 33/40: 100%|██████████| 10/10 [00:01<00:00,  6.06it/s, loss=0.856]


Epoch 33 average loss: 0.906355905532837           avg loss speed: 0.11916534900665265 epochs left: 5.09


Epoch 34/40: 100%|██████████| 10/10 [00:01<00:00,  6.26it/s, loss=0.929]


Epoch 34 average loss: 0.8754875779151916           avg loss speed: 0.09140419761339824 epochs left: 6.30


Epoch 35/40: 100%|██████████| 10/10 [00:01<00:00,  6.30it/s, loss=0.721]


Epoch 35 average loss: 0.8517052233219147           avg loss speed: 0.061907390753428104 epochs left: 8.91


Epoch 36/40: 100%|██████████| 10/10 [00:01<00:00,  6.35it/s, loss=0.811]


Epoch 36 average loss: 0.8154086709022522           avg loss speed: 0.06244089802106223 epochs left: 8.25


Epoch 37/40: 100%|██████████| 10/10 [00:01<00:00,  6.24it/s, loss=0.765]


Epoch 37 average loss: 0.7823117136955261           avg loss speed: 0.06522211035092673 epochs left: 7.39


Epoch 38/40: 100%|██████████| 10/10 [00:01<00:00,  6.15it/s, loss=0.774]


Epoch 38 average loss: 0.7334072530269623           avg loss speed: 0.0830679496129354 epochs left: 5.22


Epoch 39/40: 100%|██████████| 10/10 [00:01<00:00,  6.29it/s, loss=0.755]


Epoch 39 average loss: 0.7237970709800721           avg loss speed: 0.05324547489484144 epochs left: 7.96


Epoch 40/40: 100%|██████████| 10/10 [00:01<00:00,  6.14it/s, loss=0.719]


Epoch 40 average loss: 0.712163758277893           avg loss speed: 0.03434158762296036 epochs left: 12.00
Training complete


In [6]:
# The longest mesh (in faces) determines the transformer's sequence length.
max_length = max(len(d["faces"]) for d in dataset.data if "faces" in d)
# Tokens per face = 3 vertices x codes per vertex. This assumes the autoencoder
# emits `num_quantizers` codes per vertex (TODO confirm in meshgpt_pytorch).
# The previously hard-coded 6 corresponds to 2 quantizers, which is kept as
# the fallback.
tokens_per_face = 3 * getattr(autoencoder, "num_quantizers", 2)
max_seq = max_length * tokens_per_face
print(max_length)
print(max_seq)

transformer = MeshTransformer(
    autoencoder,
    dim = 768,
    max_seq_len = max_seq,
    condition_on_text = True
)
total_params = sum(p.numel() for p in transformer.parameters())
print(f"Total parameters: {total_params}")

124
744
Total parameters: 226891215


In [7]:
 
trainer = MeshTransformerTrainer(model = transformer,
                                 warmup_steps = 10,
                                 grad_accum_every = 1,
                                 num_train_steps = 100,
                                 dataset = dataset,
                                 learning_rate = 1e-1,
                                 batch_size = 2)

# Keep the transformer's own loss. The global `loss` still holds the
# autoencoder's final loss from its training cell, and the checkpoint name
# used to embed that number by mistake.
transformer_loss = trainer.train(20, stop_at_loss = 0.00009)

# Optional follow-up phases (previously left as commented-out code): recreate
# the trainer with learning_rate=1e-2 and then 1e-4 (other arguments
# unchanged), calling .train(80, stop_at_loss=0.00009) for each.

os.makedirs("toy_output", exist_ok=True)
if transformer_loss is None:
    # Fall back to an unversioned name if train() does not report a loss.
    trainer.save('toy_output/mesh-transformer_2.pt')
else:
    trainer.save(f'toy_output/mesh-transformer_2_{transformer_loss:.3f}.pt')

Epoch 1/20: 100%|██████████| 80/80 [00:36<00:00,  2.22it/s, loss=1.58] 


Epoch 1 average loss: 2.6676975294947622


Epoch 2/20: 100%|██████████| 80/80 [00:35<00:00,  2.27it/s, loss=0.598]


Epoch 2 average loss: 0.8363735556602478


Epoch 3/20: 100%|██████████| 80/80 [00:35<00:00,  2.25it/s, loss=0.552]


Epoch 3 average loss: 0.5185843899846077


Epoch 4/20: 100%|██████████| 80/80 [00:35<00:00,  2.28it/s, loss=0.114] 


Epoch 4 average loss: 0.29734265329316256           avg loss speed: 1.0435425050867102


Epoch 5/20: 100%|██████████| 80/80 [00:35<00:00,  2.25it/s, loss=0.142] 


Epoch 5 average loss: 0.14500402342528104           avg loss speed: 0.4057628428873916


Epoch 6/20: 100%|██████████| 80/80 [00:35<00:00,  2.26it/s, loss=0.052] 


Epoch 6 average loss: 0.055219741095788775           avg loss speed: 0.265090614471895


Epoch 7/20: 100%|██████████| 80/80 [00:36<00:00,  2.20it/s, loss=0.0145] 


Epoch 7 average loss: 0.0258186811581254           avg loss speed: 0.1400367914466187 epochs left: 0.18


Epoch 8/20: 100%|██████████| 80/80 [00:36<00:00,  2.22it/s, loss=0.0385] 


Epoch 8 average loss: 0.053307528025470674           avg loss speed: 0.0220399538675944 epochs left: 2.37


Epoch 9/20: 100%|██████████| 80/80 [00:36<00:00,  2.18it/s, loss=0.0916]


Epoch 9 average loss: 0.05474392602918669           avg loss speed: -0.00996194260272508


Epoch 10/20: 100%|██████████| 80/80 [00:37<00:00,  2.14it/s, loss=0.0187] 


Epoch 10 average loss: 0.041762136621400714           avg loss speed: 0.0028612417828602077 epochs left: 14.25


Epoch 11/20: 100%|██████████| 80/80 [00:40<00:00,  1.99it/s, loss=0.0666] 


Epoch 11 average loss: 0.01772295873379335           avg loss speed: 0.032214904824892684 epochs left: 0.52


Epoch 12/20: 100%|██████████| 80/80 [00:40<00:00,  1.97it/s, loss=0.218]  


Epoch 12 average loss: 0.026159881066996606           avg loss speed: 0.011916459394463647 epochs left: 2.11


Epoch 13/20: 100%|██████████| 80/80 [00:42<00:00,  1.87it/s, loss=0.0336]


Epoch 13 average loss: 0.05687048444524408           avg loss speed: -0.02832215897118052


Epoch 14/20: 100%|██████████| 80/80 [00:43<00:00,  1.86it/s, loss=0.0117] 


Epoch 14 average loss: 0.016708082618424668           avg loss speed: 0.016876358796920007 epochs left: 0.93


Epoch 15/20: 100%|██████████| 80/80 [00:43<00:00,  1.83it/s, loss=0.0144] 


Epoch 15 average loss: 0.01217394964187406           avg loss speed: 0.021072199735014393 epochs left: 0.53


Epoch 16/20: 100%|██████████| 80/80 [00:44<00:00,  1.78it/s, loss=0.0165] 


Epoch 16 average loss: 0.011738798511214555           avg loss speed: 0.01684537372396638 epochs left: 0.64


Epoch 17/20: 100%|██████████| 80/80 [00:46<00:00,  1.73it/s, loss=0.0173] 


Epoch 17 average loss: 0.014887878054287285           avg loss speed: -0.0013476011304495248


Epoch 18/20: 100%|██████████| 80/80 [00:47<00:00,  1.70it/s, loss=0.274] 


Epoch 18 average loss: 0.08145265569910407           avg loss speed: -0.06851911362997877


Epoch 19/20: 100%|██████████| 80/80 [00:46<00:00,  1.72it/s, loss=0.00861]


Epoch 19 average loss: 0.027562913956353442           avg loss speed: 0.008463530131848529 epochs left: 3.14


Epoch 20/20: 100%|██████████| 80/80 [00:48<00:00,  1.66it/s, loss=0.00905]


Epoch 20 average loss: 0.01292018317617476           avg loss speed: 0.028380966060406836 epochs left: 0.42
Training complete


In [8]:
unique_values = set(item["texts"] for item in dataset.data)
print(len(unique_values))

os.makedirs("toy_output", exist_ok=True)


def _triangle_soup_to_obj_lines(points, index_offset=0):
    """Return (vertex_lines, face_lines) of OBJ text for a flat (N, 3) point array.

    Every 3 consecutive rows form one triangle. OBJ indices are 1-based and
    shifted by ``index_offset`` so several meshes can share one file.
    """
    vertex_lines = [f"v {vertex[0]} {vertex[1]} {vertex[2]}\n" for vertex in points]
    face_lines = [f"f {i + index_offset} {i + 1 + index_offset} {i + 2 + index_offset}\n"
                  for i in range(1, len(points), 3)]
    return vertex_lines, face_lines


# Generate one mesh per text prompt and save each to its own OBJ file.
coords = []
for text in unique_values:
    print(f"doing {text}")
    faces_coordinates = transformer.generate(texts = [text])
    coords.append(faces_coordinates)
    numpy_data = faces_coordinates[0].cpu().numpy().reshape(-1, 3)

    vertex_lines, face_lines = _triangle_soup_to_obj_lines(numpy_data)
    obj_file_path = f'toy_output/3d_output_{text}.obj'
    with open(obj_file_path, "w") as file:
        file.write("".join(vertex_lines) + "".join(face_lines))

    print(obj_file_path)

# Combine all generated meshes into a single OBJ, spaced out along x.
all_vertices = []
all_faces = []
vertex_offset = 0
translation_distance = 0.3

for r, faces_coordinates in enumerate(coords):
    # .copy(): for CPU tensors .cpu() returns the same tensor and .numpy()
    # shares its memory, so shifting in place could silently modify the
    # tensors kept in `coords` (and shift them again on every re-run).
    numpy_data = faces_coordinates[0].cpu().numpy().reshape(-1, 3).copy()

    # Translate the model to avoid overlapping (X coordinate only).
    numpy_data[:, 0] += translation_distance * (r / 0.2 - 1)

    vertex_lines, face_lines = _triangle_soup_to_obj_lines(numpy_data, vertex_offset)
    all_vertices.extend(vertex_lines)
    all_faces.extend(face_lines)

    # Next model's face indices start after this model's vertices.
    vertex_offset += len(numpy_data)

obj_file_content = "".join(all_vertices) + "".join(all_faces)

obj_file_path = "toy_output/3d_models_all.obj"
with open(obj_file_path, "w") as file:
    file.write(obj_file_content)

print(obj_file_path)


4
doing cylinder


100%|██████████| 744/744 [00:13<00:00, 56.81it/s]


toy_output/3d_output_cylinder.obj
doing icosphere


 50%|█████     | 372/744 [00:07<00:07, 47.92it/s]


toy_output/3d_output_icosphere.obj
doing cone


 10%|▉         | 72/744 [00:02<00:21, 30.56it/s]


toy_output/3d_output_cone.obj
doing cube


 50%|█████     | 372/744 [00:10<00:10, 34.76it/s]


toy_output/3d_output_cube.obj
toy_output/3d_models_all.obj


In [None]:

# Output folder for the temperature sweeps; it is not created anywhere else.
os.makedirs("./results", exist_ok=True)

temperatures = np.arange(0, 1.0, 0.1)
# Translation distance between neighbouring models along x.
translation_distance = 0.3

coords_all = []
for text in set(item["texts"] for item in dataset.data):
    print(f"Doing {text}")
    coords = [transformer.generate(temperature=temperature, texts = [text])
              for temperature in temperatures]
    coords_all.append(coords)

    all_vertices = []
    all_faces = []
    vertex_offset = 0

    for r, faces_coordinates in enumerate(coords):
        # .copy() so the in-place shift below cannot modify the tensors kept
        # in `coords_all` (.cpu()/.numpy() do not copy CPU tensors).
        numpy_data = faces_coordinates[0].cpu().numpy().reshape(-1, 3).copy()

        # Translate the model to avoid overlapping (X coordinate only).
        numpy_data[:, 0] += translation_distance * (r / 0.2 - 1)

        all_vertices.extend(f"v {vertex[0]} {vertex[1]} {vertex[2]}\n" for vertex in numpy_data)
        # Every 3 consecutive vertices form one triangle; OBJ indices are 1-based.
        all_faces.extend(f"f {i + vertex_offset} {i + 1 + vertex_offset} {i + 2 + vertex_offset}\n"
                         for i in range(1, len(numpy_data), 3))

        vertex_offset += len(numpy_data)

    obj_file_content = "".join(all_vertices) + "".join(all_faces)

    obj_file_path = f"./results/3d_models_{text}_temps.obj"
    with open(obj_file_path, "w") as file:
        file.write(obj_file_content)

    print(obj_file_path)


In [None]:

def loadModels(autoencoder_path=r"mesh-encoder_last.pt",
               transformer_path=r"mesh-transformer.pt",
               checkpoint_folder=r"F:\MachineLearning\Mesh\MeshGPT\checkpoints",
               mesh_dataset=None):
    """Rebuild the autoencoder and transformer and load their checkpoints.

    Args:
        autoencoder_path: checkpoint passed to ``MeshAutoencoderTrainer.load``.
        transformer_path: checkpoint passed to ``MeshTransformerTrainer.load``.
        checkpoint_folder: passed to ``MeshTransformerTrainer``. The default is
            the original author's absolute local path; override it elsewhere.
        mesh_dataset: dataset given to both trainers and used to size
            ``max_seq_len``. Defaults to the notebook-global ``dataset``.

    Returns:
        (transformer, autoencoder): ``.model`` of each trainer after loading.
    """
    if mesh_dataset is None:
        mesh_dataset = dataset

    autoencoder = MeshAutoencoder(
        dim = 576,
        encoder_depth = 6,
        decoder_depth = 6,
        num_discrete_coors = 128,
        local_attn_depth = 0,
    )
    # The trainer is used here to restore the checkpoint; its optimizer
    # settings presumably do not affect the loaded weights.
    autoencoder_trainer = MeshAutoencoderTrainer(model = autoencoder,
                                                 learning_rate = 1e-1,
                                                 checkpoint_every_epoch = 5,
                                                 warmup_steps = 10,
                                                 dataset = mesh_dataset,
                                                 num_train_steps = 100,
                                                 batch_size = 2,
                                                 grad_accum_every = 1)
    autoencoder_trainer.load(autoencoder_path)
    encoder = autoencoder_trainer.model

    max_length = max(len(d["faces"]) for d in mesh_dataset if "faces" in d)
    # 3 vertices per face x codes per vertex. This assumes `num_quantizers`
    # codes per vertex (TODO confirm). The fallback of 2 reproduces the
    # previous hard-coded factor of 6.
    max_seq = max_length * 3 * getattr(autoencoder, "num_quantizers", 2)

    transformer = MeshTransformer(
        autoencoder,
        dim = 768,
        max_seq_len = max_seq,
        condition_on_text = True)

    trainer = MeshTransformerTrainer(model = transformer,
                                     warmup_steps = 10,
                                     grad_accum_every = 1,
                                     num_train_steps = 100,
                                     checkpoint_folder = checkpoint_folder,
                                     dataset = mesh_dataset,
                                     learning_rate = 1e-3,
                                     batch_size = 2)
    trainer.load(transformer_path)
    transformer = trainer.model
    return transformer, encoder

# transformer, autoencoder = loadModels()