Skip to content

Commit

Permalink
1. Bug fix about weights release during training
Browse files Browse the repository at this point in the history
2. Upgrade external project ManagedCUDA dependency to CUDA 12
3. Add more debug information for data shuffling.
  • Loading branch information
zhongkaifu committed Feb 7, 2023
1 parent 834b9b5 commit e3de9c6
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 10 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -103,3 +103,7 @@ Temporary Items
/ExternalProjects/managedCuda/CudaBlas/bin/Release/net7.0
/ExternalProjects/managedCuda/CudaRand/bin/Release/net7.0
/ExternalProjects/managedCuda/NVRTC/bin/Release/net7.0
/ExternalProjects/managedCuda/CudaBlas/bin/Debug/net7.0
/ExternalProjects/managedCuda/CudaRand/bin/Debug/net7.0
/ExternalProjects/managedCuda/NVRTC/bin/Debug/net7.0
/ExternalProjects/managedCuda/ManagedCUDA/bin/Debug/net7.0
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ namespace ManagedCuda.CudaBlas
public static class CudaBlasNativeMethods
{
//32bit is no more supported, only 64 bit...
internal const string CUBLAS_API_DLL_NAME = "cublas64_11";
internal const string CUBLAS_API_DLL_NAME = "cublas64_12";


#if (NETCOREAPP)
Expand Down
2 changes: 1 addition & 1 deletion ExternalProjects/managedCuda/NVRTC/NVRTCNativeMethods.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace ManagedCuda.NVRTC
/// <summary/>
public static class NVRTCNativeMethods
{
internal const string NVRTC_API_DLL_NAME = "nvrtc64_112_0";
internal const string NVRTC_API_DLL_NAME = "nvrtc64_120_0";

#if (NETCOREAPP)
internal const string NVRTC_API_DLL_NAME_LINUX = "nvrtc";
Expand Down
22 changes: 14 additions & 8 deletions Seq2SeqSharp/Corpus/ParallelCorpus.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ public class ParallelCorpus<T> : IParallelCorpus<T> where T : ISntPairBatch, new
private TooLongSequence m_tooLongSequence = TooLongSequence.Ignore;

private string m_binaryDataSetFilePath = "";

private int m_batchNumInTotal = 0;

public (List<Dictionary<string, int>>, List<Dictionary<string, int>>) CountTokenFreqs()
{
Expand Down Expand Up @@ -320,8 +320,6 @@ public void PrepareDataSet()
try
{
(var length2offsets, var length2counts, string tmpDataSetFilePath) = BuildIndex();

int batchNum = 0;
Logger.WriteLine($"Start to sort and shuffle data set by length.");

m_binaryDataSetFilePath = tmpDataSetFilePath + ".sorted";
Expand Down Expand Up @@ -367,10 +365,10 @@ public void PrepareDataSet()
bw.Write(String.Join("\n", srcLines));
bw.Write(String.Join("\n", tgtLines));

batchNum++;
if (batchNum % 10000 == 0)
m_batchNumInTotal++;
if (m_batchNumInTotal % 10000 == 0)
{
Logger.WriteLine($"Batch '{batchNum}' has been processed.");
Logger.WriteLine($"Batch '{m_batchNumInTotal}' has been processed.");
}


Expand All @@ -387,7 +385,7 @@ public void PrepareDataSet()

File.Delete(tmpDataSetFilePath);

Logger.WriteLine($"Finished to sort and shuffle data set by length.");
Logger.WriteLine($"Finished to sort and shuffle data set by length. Total batch size = '{m_batchNumInTotal}'");
}
catch (Exception err)
{
Expand All @@ -398,6 +396,8 @@ public void PrepareDataSet()
public IEnumerator<T> GetEnumerator()
{
PrepareDataSet();
int batchIdx = 0;
int currentBatchPercent = 0;

using (MemoryMappedFile mmf = MemoryMappedFile.CreateFromFile(m_binaryDataSetFilePath))
using (MemoryMappedViewStream mms = mmf.CreateViewStream())
Expand All @@ -416,13 +416,19 @@ public IEnumerator<T> GetEnumerator()

string[] srcLines = br.ReadString().Split("\n");
string[] tgtLines = br.ReadString().Split("\n");

batchIdx++;

for (int i = 0; i < sizeInBatch; i++)
{
var srcLine = srcLines[i];
var tgtLine = tgtLines[i];

if ((100 * batchIdx / m_batchNumInTotal) > currentBatchPercent)
{
Logger.WriteLine($"Processing batch '{batchIdx}/{m_batchNumInTotal}'. The '{i}th' record in this batch is: Source = '{srcLine}' Target = '{tgtLine}'");
currentBatchPercent++;
}

SntPair sntPair = new SntPair(srcLine, tgtLine);
outputs.Add(sntPair);
}
Expand Down
5 changes: 5 additions & 0 deletions Seq2SeqSharp/Tools/ComputeGraphTensor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,8 @@ void backward()
res.Dispose();
}
m_backprop.Add(backward);

res.UnbindFromComputeGraph();
}

return res;
Expand Down Expand Up @@ -241,6 +243,8 @@ void backward()
res.Dispose();
}
m_backprop.Add(backward);

res.UnbindFromComputeGraph();
}

return res;
Expand Down Expand Up @@ -2011,6 +2015,7 @@ void backward()
}
m_backprop.Add(backward);

res.UnbindFromComputeGraph();
srcT.UnbindFromComputeGraph();
alphaT.UnbindFromComputeGraph();
betaT.UnbindFromComputeGraph();
Expand Down

0 comments on commit e3de9c6

Please sign in to comment.